# Training VAE for Anomaly Detection

The goal is to train Variational Autoencoders for anomaly detection.

## Dataset

ECG Heartbeat Categorization Dataset: https://www.kaggle.com/datasets/shayanfazeli/heartbeat?resource=download


In [13]:
# Imports
import pytorch_lightning as pl
import pandas as pd
from pathlib import Path
from torch import Tensor, FloatTensor, nn
from pandarallel import pandarallel
from typing import Optional
from torch.utils.data import (
    Dataset,
    DataLoader,
    Subset,
    ConcatDataset,
)
import torch
import os
import numpy as np
from sklearn.model_selection import train_test_split
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import shutil
from torchinfo import summary
import torch.nn.functional as F
from datetime import timedelta
from tqdm import tqdm
from typing import List
from datetime import datetime
from torchmetrics import Accuracy, ConfusionMatrix, AUROC, ROC
import matplotlib.pyplot as plt
import random

In [14]:
# Initial setup for reproducibility: fix the global seed through Lightning's
# `seed_everything` helper before any data split or model is created.
SEED_VALUE = 500
pl.seed_everything(SEED_VALUE)


# Constants
# NOTE(review): the names below (windowing, pH threshold, CTG date difference,
# z-score / min-max parameters) look like leftovers from a CTG project. None of
# them is referenced in the part of the notebook reviewed here -- confirm
# before relying on them or removing them.
WINDOW_LENGTH = timedelta(minutes=10)
WINDOW_STRIDE = timedelta(minutes=10)
PH_THRESHOLD = 7.05
MAX_DATETIME_DIFFERENCE_CTG = timedelta(days=365)
# Six precomputed numbers; their meaning/order is not documented in this file
# (presumably count / mean / sum-of-squares per channel -- TODO confirm).
CALCULATED_Z_SCORE_PARAMS = [
    1016201603,
    138.06500900606565,
    324618849082.29016,
    1016201462,
    21.75170609426872,
    606363058472.4089,
]
Z_SCORE_OFFSET = 2
# Presumably [min, max, min, max] for two channels -- TODO confirm.
MIN_MAX_PARAMS = [0, 240, 0, 127]

Global seed set to 500


In [15]:
# Initialize Parallel Apply
# Starts pandarallel (with progress bars) and registers tqdm's pandas
# integration. NOTE(review): neither `parallel_apply` nor `progress_apply` is
# used in the part of the notebook reviewed here -- confirm this cell is still
# needed.
pandarallel.initialize(progress_bar=True)
tqdm.pandas()

INFO: Pandarallel will run on 16 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


In [39]:
# Load the abnormal-heartbeat CSV (it has no header row) into `df` and preview
# the first rows.
# NOTE(review): the recorded output of this cell, "(4046, 188)", looks like a
# `.shape` result rather than `head()` output, so the cell was presumably
# edited after it last ran -- confirm which file `df` is meant to hold.
# NOTE(review): the name `ecg_dataset` is reassigned to an ECGDataset instance
# in a later cell.
ecg_dataset = Path("../ecg/ptbdb_abnormal.csv")
df = pd.read_csv(ecg_dataset,header=None)
df.head()

(4046, 188)


In [17]:
# Load the abnormal-heartbeat CSV (no header row) and inspect a single row.
ecg_dataset_abnormal = Path("../ecg/ptbdb_abnormal.csv")
df_abnormal = pd.read_csv(ecg_dataset_abnormal,header=None)
df_abnormal.head()
# Take the third recording (row position 2).
ecg_row = df_abnormal.iloc[2]
# The last column is the class label (prints 1.0 for this abnormal row).
print(ecg_row.tail(1).values[0])
# Drop the label so only the 187 signal samples remain.
ecg_row = ecg_row[:-1]
print(ecg_row)

1.0
0      1.000000
1      0.951613
2      0.923963
3      0.853303
4      0.791859
         ...   
182    0.000000
183    0.000000
184    0.000000
185    0.000000
186    0.000000
Name: 2, Length: 187, dtype: float64


In [18]:
# Visualize test time series
# Generic function to visualize ECGs
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from datetime import timedelta
from typing import List


def visualize_ecg(
    ecg: List,
    show_points: bool = False,
):
    """Build a single-panel Plotly figure showing one ECG trace.

    Args:
        ecg: Sequence of signal amplitudes; each value is plotted against its
            position in the sequence.
        show_points: When True, draw a marker on every sample in addition to
            the connecting line.

    Returns:
        The Plotly figure holding the red ECG trace.
    """
    figure = make_subplots(rows=1, cols=1, shared_xaxes=True)
    draw_mode = "lines+markers" if show_points else "lines"

    # x axis is simply the sample index 0..len(ecg)-1.
    sample_positions = list(range(len(ecg)))
    trace = go.Scatter(
        x=sample_positions,
        y=ecg,
        name="ECG",
        mode=draw_mode,
        # White markers with a thin black outline (only visible in marker mode).
        marker={"color": "#FFFFFF", "size": 5, "line": {"color": "black", "width": 1}},
        line={"color": "red", "width": 2},
    )
    figure.append_trace(trace, row=1, col=1)
    return figure

In [19]:
def create_and_save_roc_curve(fpr, tpr, thresholds, filename: Path):
    """Plot an ROC curve as individual points and save it to ``filename``.

    Args:
        fpr: Tensor of false-positive rates.
        tpr: Tensor of true-positive rates, aligned with ``fpr``.
        thresholds: Thresholds belonging to the points. Accepted so the
            function can be called with the full ROC result, but not plotted.
        filename: Destination handed to ``Figure.savefig``.

    Returns:
        ``filename``, unchanged, so the caller can log the artifact.
    """
    fig, ax = plt.subplots()
    try:
        # detach() so tensors that still carry gradients can be converted.
        ax.plot(fpr.detach().cpu().numpy(), tpr.detach().cpu().numpy(), "o")

        # labels and title
        ax.set_xlabel("False Positive Rate (FPR)")
        ax.set_ylabel("True Positive Rate (TPR)")
        ax.set_title("ROC Curve")
        fig.savefig(filename)
    finally:
        # Close (not just clear) the figure: clearing keeps it registered with
        # pyplot, so repeated calls would accumulate open figures.
        plt.close(fig)
    return filename

In [37]:
# Plot one randomly chosen recording. The last column of the CSV holds the
# class label rather than a signal sample, so it is dropped before plotting.
random_ecg = df.sample().iloc[:, :-1].values.flatten().tolist()
fig = visualize_ecg(random_ecg)
fig.show()

In [21]:
import math
from datetime import datetime
from typing import List


class ECGDataset(Dataset):
    """ECG heartbeats read from ``ptbdb_normal.csv`` and ``ptbdb_abnormal.csv``.

    Both files are read without a header row and stacked with the normal
    recordings first, followed by the abnormal ones. The last column of each
    row is treated as the class label; every other column is a signal sample.
    """

    def __init__(self, data_dir: Path):
        """Load both CSV files found in ``data_dir``.

        Args:
            data_dir: Directory containing ``ptbdb_normal.csv`` and
                ``ptbdb_abnormal.csv``. Strings are accepted as well.
        """
        data_dir = Path(data_dir)
        self.data_dir = data_dir
        self.normal_df = pd.read_csv(data_dir / "ptbdb_normal.csv", header=None)
        # Row position of the last normal recording in the stacked frame.
        self.normal_last_row = self.normal_df.shape[0] - 1
        self.abnormal_df = pd.read_csv(data_dir / "ptbdb_abnormal.csv", header=None)
        self.df = pd.concat([self.normal_df, self.abnormal_df]).reset_index(drop=True)
        # Name the label column after this dataset's own last column instead of
        # relying on a notebook-global DataFrame.
        self.df = self.df.rename(columns={self.df.columns[-1]: "label"})
        self._targets = torch.tensor(self.df["label"].values)

    def __len__(self):
        """Return the total number of recordings (normal + abnormal)."""
        return self.df.shape[0]

    def __getitem__(self, index):
        """Return ``(signal, label)`` for the recording at ``index``.

        ``signal`` is a float32 tensor of shape ``(1, n_samples)``; ``label``
        is the value stored in the last column of the row.
        """
        row = self.df.iloc[index]
        label = row.iloc[-1]
        # astype() copies, so the tensor owns writable memory.
        signal = row.iloc[:-1].values.astype(np.float32)
        ecg_tensor = torch.from_numpy(signal).unsqueeze(0)
        return ecg_tensor, label

    @property
    def normal_idx(self) -> List[int]:
        """Row positions of the normal recordings."""
        return list(range(self.normal_last_row + 1))

    @property
    def abnormal_idx(self) -> List[int]:
        """Row positions of the abnormal recordings."""
        return list(range(self.normal_last_row + 1, self.df.shape[0]))

    @property
    def targets(self) -> Tensor:
        """Labels of all recordings, in dataset order."""
        return self._targets


In [22]:
# Smoke test: build the dataset and inspect its first sample.
# NOTE(review): this absolute path is machine-specific; other cells use the
# relative "../ecg" directory.
ecg_dataset = ECGDataset(data_dir=Path("/home/harshit/ecg-anomaly-detection/ecg/"))
ecg, label = ecg_dataset[0]
print(ecg.shape, ecg, label, sep="\n")

torch.Size([1, 187])
tensor([[1.0000, 0.9003, 0.3586, 0.0515, 0.0466, 0.1268, 0.1333, 0.1191, 0.1106,
         0.1130, 0.1066, 0.1070, 0.1159, 0.1224, 0.1224, 0.1195, 0.1159, 0.1224,
         0.1260, 0.1337, 0.1349, 0.1426, 0.1511, 0.1584, 0.1637, 0.1738, 0.1888,
         0.2079, 0.2310, 0.2585, 0.2946, 0.3258, 0.3626, 0.3983, 0.4295, 0.4494,
         0.4510, 0.4190, 0.3728, 0.3104, 0.2500, 0.2042, 0.1690, 0.1475, 0.1305,
         0.1244, 0.1175, 0.1167, 0.1159, 0.1187, 0.1155, 0.1139, 0.1195, 0.1167,
         0.1228, 0.1207, 0.1167, 0.1228, 0.1264, 0.1317, 0.1418, 0.1394, 0.1451,
         0.1434, 0.1410, 0.1406, 0.1382, 0.1370, 0.1321, 0.1284, 0.1284, 0.1280,
         0.1252, 0.1224, 0.1171, 0.1126, 0.1130, 0.1276, 0.1653, 0.1795, 0.1613,
         0.1767, 0.1827, 0.1746, 0.1515, 0.1479, 0.1349, 0.1228, 0.1070, 0.0981,
         0.0944, 0.0891, 0.0891, 0.0887, 0.0908, 0.0859, 0.0859, 0.0891, 0.0843,
         0.0579, 0.0000, 0.1163, 0.3096, 0.8343, 0.9643, 0.5616, 0.0814, 0.0324,
       

In [23]:
class ECGDataModule(pl.LightningDataModule):
    """Lightning data module providing train/val/test splits of the ECG data.

    Training and validation use normal recordings only. The test set combines
    the held-out normal recordings with all abnormal recordings.
    """

    def __init__(
        self,
        data_dir: Path = Path("../ecg"),
        batch_size: int = 32,
        split_seed: int = 50,
        num_workers: int = 1,
    ):
        super().__init__()

        # Make the constructor arguments available as self.hparams without
        # forwarding them to the experiment logger.
        self.save_hyperparameters(logger=False)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    def setup(self, stage: Optional[str] = None) -> None:
        """Load the dataset and create the splits (60/20/20 of normal data)."""
        seed = self.hparams.split_seed

        self.full_dataset = ECGDataset(self.hparams.data_dir)
        self.normal_dataset = Subset(self.full_dataset, self.full_dataset.normal_idx)
        self.abnormal_dataset = Subset(self.full_dataset, self.full_dataset.abnormal_idx)

        # First hold out 20% of the normal recordings for testing ...
        remaining_idx, held_out_idx = train_test_split(
            np.arange(len(self.normal_dataset)),
            test_size=0.20,
            random_state=seed,
            shuffle=True,
        )
        # ... then take 25% of what is left (20% overall) for validation.
        train_idx, val_idx = train_test_split(
            remaining_idx,
            test_size=0.25,
            random_state=seed,
            shuffle=True,
        )

        self.data_test_normal = Subset(self.normal_dataset, held_out_idx)

        # Train and val datasets contain normal samples only.
        # (A single-batch experiment can be built by sampling batch_size
        # positions from the training indexes.)
        self.data_train = Subset(self.normal_dataset, train_idx)
        self.data_val = Subset(self.normal_dataset, val_idx)

        # Test dataset: held-out normal samples followed by all abnormal ones.
        self.data_test = ConcatDataset([self.data_test_normal, self.abnormal_dataset])

    def _make_loader(self, dataset, shuffle: bool = False):
        """Wrap ``dataset`` in a DataLoader using the configured batch size/workers."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.num_workers,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.data_train, shuffle=True)

    def val_dataloader(self):
        """Sequential loader over the validation split."""
        return self._make_loader(self.data_val)

    def test_dataloader(self):
        """Sequential loader over the test split."""
        return self._make_loader(self.data_test)

In [24]:
# Define the VAE model
class Encoder(nn.Module):
    """Unidirectional LSTM encoder that returns only its final recurrent state."""

    def __init__(self, input_size=4096, hidden_size=1024, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, x):
        """Encode ``x`` and return the ``(hidden, cell)`` state tuple.

        ``x`` is batch-first: (batch_size, seq_length, input_size). The
        per-step outputs of the LSTM are discarded.
        """
        _, final_state = self.lstm(x)
        h_n, c_n = final_state
        return (h_n, c_n)


class Decoder(nn.Module):
    """LSTM decoder followed by a linear projection and a sigmoid."""

    def __init__(
        self, input_size=4096, hidden_size=1024, output_size=4096, num_layers=2
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )
        self.fc = nn.Linear(hidden_size, output_size)
        self.sig = nn.Sigmoid()

    def forward(self, x, hidden):
        """Decode ``x`` starting from the recurrent state ``hidden``.

        ``x`` is batch-first: (batch_size, seq_length, input_size); ``hidden``
        is the ``(h_0, c_0)`` tuple handed to the LSTM. Returns the sigmoid
        prediction for every step together with the final ``(hidden, cell)``.
        """
        lstm_out, final_state = self.lstm(x, hidden)
        prediction = self.sig(self.fc(lstm_out))
        h_n, c_n = final_state
        return prediction, (h_n, c_n)


class ECG_LSTM_VAE(pl.LightningModule):
    """LSTM variational autoencoder that reconstructs ECG signals.

    The final hidden state of an LSTM encoder is mapped to the mean and
    log-variance of a Gaussian latent variable. A sampled latent vector is
    repeated for every time step and decoded by a second LSTM, which starts
    from the encoder's final state. The loss of a sample can be used as its
    anomaly score.
    """

    def __init__(self, config):
        """Build the model from a configuration mapping.

        Keys read from ``config``:
            lr: learning rate for Adam.
            input_size: number of features per time step.
            hidden_size: hidden-state size of both LSTMs.
            latent_size: dimension of the latent vector z.
            num_layers: number of stacked LSTM layers.
            beta: weight of the KL term in the loss.
            save_images: whether validation figures are written and logged.
            saved_images_path: directory for those figures.
        """
        super().__init__()

        # dimensions
        self.lr = config["lr"]
        self.input_size = config["input_size"]
        self.hidden_size = config["hidden_size"]
        self.latent_size = config["latent_size"]
        self.num_layers = config["num_layers"]
        self.beta = config["beta"]
        self.save_images = config["save_images"]
        self.saved_images_path = config["saved_images_path"]
        self.anomaly_threshold = 0.01

        # lstm ae
        self.lstm_enc = Encoder(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
        )
        self.lstm_dec = Decoder(
            input_size=self.latent_size,
            output_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
        )

        self.fc21 = nn.Linear(self.hidden_size, self.latent_size)  # -> mean
        self.fc22 = nn.Linear(self.hidden_size, self.latent_size)  # -> log-variance
        # Not used in forward(); kept so existing checkpoints still load.
        self.fc3 = nn.Linear(self.latent_size, self.hidden_size)

        # ROC curve metrics; `task` is given explicitly, as in ECG_CNN_VAE.
        self.metrics = torch.nn.ModuleDict(
            {
                "test_auc": AUROC(task="binary", num_classes=None, max_fpr=0.15),
                "test_roc": ROC(task="binary", num_classes=None),
            }
        )

    def reparametize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: tensor of shape (batch_size, seq_len, input_size).

        Returns:
            ``(mean, logvar, x_hat)`` where ``x_hat`` has the shape of ``x``.
        """
        _, seq_len, _ = x.shape

        # encode input space to hidden space
        enc_hidden = self.lstm_enc(x)
        # h_n is (num_layers, batch, hidden): use the top layer's final state.
        # (A plain .view(batch, hidden) only works for a single layer.)
        enc_h = enc_hidden[0][-1]

        # extract latent variable z (hidden space to latent space)
        mean = self.fc21(enc_h)
        logvar = self.fc22(enc_h)
        z = self.reparametize(mean, logvar)  # batch_size x latent_size

        # decode latent space to input space: feed every sample's own z at
        # each time step. unsqueeze first so the repeat runs along the time
        # axis and never mixes latents of different samples.
        z = z.unsqueeze(1).repeat(1, seq_len, 1)
        x_hat, _ = self.lstm_dec(z, enc_hidden)
        return mean, logvar, x_hat

    def loss_function(self, *args, **kwargs) -> dict:
        r"""Compute the VAE loss.

        KL(N(\mu, \sigma), N(0, 1)) =
            \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}

        Positional arguments: ``recons, input, mu, log_var`` and optionally
        ``pairwise`` (also accepted as a keyword, default False).

        With ``pairwise`` false the reconstruction term is the MSE over all
        elements and the KL term is summed over latent dimensions and over the
        batch. With ``pairwise`` true every returned entry holds one value per
        sample, usable as an anomaly score.

        NOTE(review): the batch-level KL term is a sum over the batch, whereas
        ECG_VAE averages it -- confirm which weighting is intended.
        """
        recons, target, mu, log_var = args[:4]
        # `pairwise` is optional; the training/validation steps do not pass it.
        pairwise = args[4] if len(args) > 4 else kwargs.get("pairwise", False)

        kld_per_sample = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1)
        if pairwise:
            # Reduce the element-wise errors to one value per sample so they
            # can be added to the per-sample KL term.
            recons_loss = (
                F.mse_loss(recons, target, reduction="none")
                .flatten(start_dim=1)
                .mean(dim=1)
            )
            kld_loss = kld_per_sample
        else:
            recons_loss = F.mse_loss(recons, target)
            kld_loss = torch.sum(kld_per_sample, dim=0)

        loss = recons_loss + (self.beta * kld_loss)

        return {
            "loss": loss,
            "Reconstruction_Loss": recons_loss,
            "KLD": -kld_loss,
        }

    def configure_optimizers(self):
        """Adam with ReduceLROnPlateau driven by ``val_loss``."""
        optimizer = Adam(self.parameters(), lr=(self.lr))
        lr_scheduler = ReduceLROnPlateau(
            optimizer,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "val_loss",
        }

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses for one batch."""
        x, _ = batch
        mu, log_var, x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x, mu, log_var)
        self.log(
            "train_loss", loss_dict["loss"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "kld_loss", loss_dict["KLD"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "recon_loss",
            loss_dict["Reconstruction_Loss"],
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss_dict

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses; keep tensors for plotting."""
        x, _ = batch
        mu, log_var, x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x, mu, log_var)
        self.log(
            "val_loss", loss_dict["loss"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "val_kld_loss", loss_dict["KLD"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "val_recon_loss",
            loss_dict["Reconstruction_Loss"],
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return {"reconstructed": x_out, "original": x, "loss": loss_dict["loss"]}

    def validation_epoch_end(self, outputs):
        """Save and log figures of the first validation sample and its reconstruction."""
        if not self.save_images:
            return
        os.makedirs(self.saved_images_path, exist_ok=True)
        recon = torch.cat([tmp["reconstructed"] for tmp in outputs])
        orig = torch.cat([tmp["original"] for tmp in outputs])
        orig_saved_filename = self.save_signal_figure(orig[0], "original")
        self.logger.experiment.log_artifact(self.logger.run_id, orig_saved_filename)
        recon_saved_filename = self.save_signal_figure(recon[0], "reconstructed")
        self.logger.experiment.log_artifact(self.logger.run_id, recon_saved_filename)

    def save_signal_figure(self, signal_tensor, name):
        """Write one signal as a PNG and return the file name.

        Args:
            signal_tensor: one sample of shape (seq_len, n_features).
            name: prefix of the file name.
        """
        signal = signal_tensor.detach().cpu()
        if signal.shape[0] == 1:
            # A single time step carrying the whole beat as features: plot the
            # feature vector instead of a single point.
            series = signal[0]
        else:
            # Several time steps: plot the first channel over time.
            series = signal[:, 0]
        fig = visualize_ecg(series.numpy())
        saved_filename = (
            f"{self.saved_images_path}/{name}_epoch_{self.current_epoch+1}.png"
        )
        fig.write_image(saved_filename)
        return saved_filename

    def test_step(self, batch, batch_idx):
        """Return the batch loss plus per-sample anomaly scores and labels."""
        x, labels = batch
        mu, log_var, x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x, mu, log_var)
        per_sample = self.loss_function(x_out, x, mu, log_var, True)
        return {
            "mu": mu,
            "log_var": log_var,
            "loss": loss_dict["loss"],
            "scores": per_sample["loss"],
            "labels": labels,
        }

    def test_epoch_end(self, outputs):
        """Aggregate the test-step results and log ``test_loss`` and ``test_auc``.

        Lightning calls this hook with the list of ``test_step`` results. The
        per-sample loss serves as the anomaly score; labels are assumed to be
        0 for normal and 1 for abnormal recordings.
        """
        if not outputs:
            return
        scores = torch.cat([out["scores"] for out in outputs])
        labels = torch.cat([out["labels"] for out in outputs]).long()
        self.log("test_loss", torch.stack([out["loss"] for out in outputs]).mean())
        self.log("test_auc", self.metrics["test_auc"](scores, labels))

In [25]:
# Define the fully connected AE model
class ECGAE(pl.LightningModule):
    """Fully connected autoencoder for fixed-length ECG beats.

    Every sample is flattened to a vector, compressed to 8 values by the
    encoder and reconstructed by the decoder (sigmoid output).
    """

    def __init__(self, config):
        """Build the model from a configuration mapping.

        Keys read from ``config``:
            lr: learning rate for Adam.
            save_images: whether validation figures are written and logged.
            saved_images_path: directory for those figures.
            batch_size: stored on the module (not used by the model itself).
            n_features (optional): length of the flattened input vector,
                default 188. ECGDataset yields 187 signal samples per
                recording, so set this to 187 when feeding its output.
        """
        super().__init__()

        # dimensions
        self.lr = config["lr"]
        self.save_images = config["save_images"]
        self.saved_images_path = config["saved_images_path"]
        self.batch_size = config["batch_size"]
        self.n_features = config.get("n_features", 188)

        self.encoder = nn.Sequential(
            nn.Linear(self.n_features, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 32),
            nn.ReLU(True),
            nn.Linear(32, 16),
            nn.ReLU(True),
            nn.Linear(16, 8))

        self.decoder = nn.Sequential(
            nn.Linear(8, 16),
            nn.ReLU(True),
            nn.Linear(16, 32),
            nn.ReLU(True),
            nn.Linear(32, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, self.n_features),
            nn.Sigmoid())

    @staticmethod
    def _flatten(x):
        """Reshape a batch to (batch_size, n_features), whatever its layout.

        Works for (B, n, 1) and (B, 1, n) alike and, unlike squeeze(), keeps
        the batch axis when the batch holds a single sample.
        """
        return x.reshape(x.shape[0], -1)

    def forward(self, x):
        """Reconstruct ``x``; returns a (batch_size, n_features) tensor."""
        z = self.encoder(self._flatten(x))
        return self.decoder(z)

    def loss_function(self, *args, **kwargs) -> dict:
        """Mean squared error between ``args[0]`` (reconstruction) and ``args[1]`` (input)."""
        recons = args[0]
        target = self._flatten(args[1])

        loss = F.mse_loss(recons, target)
        return {
            "loss": loss,
        }

    def configure_optimizers(self):
        """Adam with ReduceLROnPlateau driven by ``val_loss``."""
        optimizer = Adam(self.parameters(), lr=(self.lr))
        lr_scheduler = ReduceLROnPlateau(
            optimizer,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "val_loss",
        }

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        x, _ = batch
        x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x)
        self.log(
            "train_loss", loss_dict["loss"], on_step=True, on_epoch=True, prog_bar=True
        )
        return loss_dict

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss; keep tensors for plotting."""
        x, _ = batch
        x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x)
        self.log(
            "val_loss", loss_dict["loss"], on_step=True, on_epoch=True, prog_bar=True
        )
        return {
            "reconstructed": x_out,
            "original": self._flatten(x),
            "loss": loss_dict["loss"],
        }

    def validation_epoch_end(self, outputs):
        """Save and log figures of the first validation sample.

        The original signal is written once (first epoch); its reconstruction
        is written every epoch.
        """
        if not self.save_images:
            return
        os.makedirs(self.saved_images_path, exist_ok=True)
        recon = torch.cat([tmp["reconstructed"] for tmp in outputs])
        orig = torch.cat([tmp["original"] for tmp in outputs])
        if self.current_epoch == 0:
            orig_saved_filename = self.save_signal_figure(orig[0], "original")
            self.logger.experiment.log_artifact(self.logger.run_id, orig_saved_filename)
        recon_saved_filename = self.save_signal_figure(recon[0], "reconstructed")
        self.logger.experiment.log_artifact(self.logger.run_id, recon_saved_filename)

    def save_signal_figure(self, signal_tensor, name):
        """Write one 1-D signal as a PNG and return the file name."""
        ecg = signal_tensor.cpu().detach().numpy()
        fig = visualize_ecg(ecg)
        saved_filename = (
            f"{self.saved_images_path}/{name}_epoch_{self.current_epoch+1}.png"
        )
        fig.write_image(saved_filename)
        return saved_filename

    def test_step(self, batch, batch_idx):
        """Reconstruct a test batch.

        forward() returns a single tensor (this model has no latent
        mean/log-variance), so only the reconstruction, its loss and the
        labels are returned.
        """
        x, labels = batch
        output = self(x)
        loss_dict = self.loss_function(output, x)
        return {"output": output, "loss": loss_dict["loss"], "labels": labels}

In [26]:
class ECG_VAE(pl.LightningModule):
    """Fully connected variational autoencoder for fixed-length ECG beats.

    Every sample is flattened to a vector and encoded to ``hidden_size``
    values, from which the mean and log-variance of the latent Gaussian are
    computed. A sampled latent vector is projected back to ``hidden_size`` and
    decoded (tanh output).
    """

    def __init__(self, config):
        """Build the model from a configuration mapping.

        Keys read from ``config``:
            lr: learning rate for Adam.
            input_size, num_layers: stored on the module; not used to size
                any layer of this model.
            hidden_size: width of the encoder output / decoder input.
            latent_size: dimension of the latent vector z.
            beta: weight of the KL term in the loss.
            save_images: whether validation figures are written and logged.
            saved_images_path: directory for those figures.
            n_features (optional): length of the flattened input vector,
                default 188. ECGDataset yields 187 signal samples per
                recording, so set this to 187 when feeding its output.
        """
        super().__init__()

        # dimensions
        self.lr = config["lr"]
        self.input_size = config["input_size"]
        self.hidden_size = config["hidden_size"]
        self.latent_size = config["latent_size"]
        self.num_layers = config["num_layers"]
        self.beta = config["beta"]
        self.save_images = config["save_images"]
        self.saved_images_path = config["saved_images_path"]
        self.n_features = config.get("n_features", 188)

        # normal vae
        self.encoder = nn.Sequential(
            nn.Linear(self.n_features, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 32),
            nn.ReLU(True),
            nn.Linear(32, 16),
            nn.ReLU(True),
            nn.Linear(16, self.hidden_size))

        self.decoder = nn.Sequential(
            nn.Linear(self.hidden_size, 16),
            nn.ReLU(True),
            nn.Linear(16, 32),
            nn.ReLU(True),
            nn.Linear(32, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, self.n_features),
            nn.Tanh())

        self.fc21 = nn.Linear(self.hidden_size, self.latent_size)  # -> mean
        self.fc22 = nn.Linear(self.hidden_size, self.latent_size)  # -> log-variance
        self.fc3 = nn.Linear(self.latent_size, self.hidden_size)   # z -> decoder input

    @staticmethod
    def _flatten(x):
        """Reshape a batch to (batch_size, n_features), whatever its layout.

        Works for (B, n, 1) and (B, 1, n) alike and, unlike squeeze(), keeps
        the batch axis when the batch holds a single sample.
        """
        return x.reshape(x.shape[0], -1)

    def reparametize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        """Encode, sample and decode a batch.

        Returns:
            ``(mean, logvar, x_hat)`` with ``x_hat`` of shape
            (batch_size, n_features).
        """
        # encode input space to hidden space
        enc = self.encoder(self._flatten(x))

        # extract latent variable z (hidden space to latent space)
        mean = self.fc21(enc)
        logvar = self.fc22(enc)
        z = self.reparametize(mean, logvar)  # batch_size x latent_size

        # decode latent space to input space
        x_hat = self.decoder(self.fc3(z))
        return mean, logvar, x_hat

    def loss_function(self, *args, **kwargs) -> dict:
        r"""Compute the VAE loss.

        KL(N(\mu, \sigma), N(0, 1)) =
            \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}

        Positional arguments: ``recons, input, mu, log_var``. The
        reconstruction term is the MSE over all elements; the KL term is
        summed over latent dimensions and averaged over the batch.
        """
        recons, target, mu, log_var = args[:4]
        target = self._flatten(target)

        recons_loss = F.mse_loss(recons, target)

        kld_loss = torch.mean(
            -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0
        )

        loss = recons_loss + (self.beta * kld_loss)

        return {
            "loss": loss,
            "Reconstruction_Loss": recons_loss,
            "KLD": -kld_loss,
        }

    def configure_optimizers(self):
        """Adam with ReduceLROnPlateau driven by ``val_loss``."""
        optimizer = Adam(self.parameters(), lr=(self.lr))
        lr_scheduler = ReduceLROnPlateau(
            optimizer,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "val_loss",
        }

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses for one batch."""
        x, _ = batch
        mu, log_var, x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x, mu, log_var)
        self.log(
            "train_loss", loss_dict["loss"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "kld_loss", loss_dict["KLD"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "recon_loss",
            loss_dict["Reconstruction_Loss"],
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss_dict

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses; keep tensors for plotting."""
        x, _ = batch
        mu, log_var, x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x, mu, log_var)
        self.log(
            "val_loss", loss_dict["loss"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "val_kld_loss", loss_dict["KLD"], on_step=True, on_epoch=True, prog_bar=True
        )
        self.log(
            "val_recon_loss",
            loss_dict["Reconstruction_Loss"],
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return {
            "reconstructed": x_out,
            "original": self._flatten(x),
            "loss": loss_dict["loss"],
        }

    def validation_epoch_end(self, outputs):
        """Save and log figures of the first validation sample.

        The original signal is written once (first epoch); its reconstruction
        is written every epoch.
        """
        if not self.save_images:
            return
        os.makedirs(self.saved_images_path, exist_ok=True)
        recon = torch.cat([tmp["reconstructed"] for tmp in outputs])
        orig = torch.cat([tmp["original"] for tmp in outputs])
        if self.current_epoch == 0:
            orig_saved_filename = self.save_signal_figure(orig[0], "original")
            self.logger.experiment.log_artifact(self.logger.run_id, orig_saved_filename)
        recon_saved_filename = self.save_signal_figure(recon[0], "reconstructed")
        self.logger.experiment.log_artifact(self.logger.run_id, recon_saved_filename)

    def save_signal_figure(self, signal_tensor, name):
        """Write one 1-D signal as a PNG and return the file name."""
        ecg = signal_tensor.cpu().detach().numpy()
        fig = visualize_ecg(ecg)
        saved_filename = (
            f"{self.saved_images_path}/{name}_epoch_{self.current_epoch+1}.png"
        )
        fig.write_image(saved_filename)
        return saved_filename

    def test_step(self, batch, batch_idx):
        """Return the latent statistics and reconstruction of a test batch."""
        x, _labels = batch
        mu, log_var, output = self(x)
        return {"mu": mu, "log_var": log_var, "output": output}

In [27]:
class ECG_CNN_VAE(pl.LightningModule):
    """Convolutional variational autoencoder (VAE) for ECG anomaly detection.

    The encoder is a stack of stride-2 ``Conv1d`` blocks followed by two linear
    heads producing the mean and log-variance of the latent distribution. The
    decoder mirrors the encoder with ``ConvTranspose1d`` blocks. At test time
    the per-sample reconstruction error is used as the score fed to the
    AUROC / ROC metrics.
    """

    # Time steps left by the encoder for a 187-sample input
    # (187 -> 94 -> 47 -> 24 -> 12 -> 6 through five stride-2 convolutions).
    _BOTTLENECK_LENGTH = 6
    # Preferred index of the validation sample rendered after each epoch.
    _PREVIEW_SAMPLE_INDEX = 200
    # Options shared by the per-step ``self.log`` calls of this module.
    _LOG_OPTIONS = {"on_step": True, "on_epoch": True, "prog_bar": True}

    def __init__(self, config):
        """Build the metrics, encoder, latent heads and decoder.

        Args:
            config: Mapping with the keys ``lr``, ``input_size``,
                ``hidden_size``, ``latent_size``, ``num_layers``, ``beta``,
                ``save_images``, ``saved_images_path`` and ``batch_size``.
                ``input_size``, ``hidden_size`` and ``num_layers`` are stored
                but do not influence the layers built here.
        """
        super().__init__()

        self.lr = config["lr"]
        self.input_size = config["input_size"]
        self.hidden_size = config["hidden_size"]
        self.latent_size = config["latent_size"]
        self.num_layers = config["num_layers"]
        self.beta = config["beta"]
        self.save_images = config["save_images"]
        self.saved_images_path = config["saved_images_path"]
        self.batch_size = config["batch_size"]

        # Test-time metrics. "test_roc_kld" is currently not updated anywhere;
        # it is kept so the registered sub-modules stay the same.
        self.metrics = torch.nn.ModuleDict(
            {
                "test_auc": AUROC(task='binary', num_classes=None, max_fpr=0.15),
                "test_roc": ROC(task='binary', num_classes=None),
                "test_roc_kld": ROC(task='binary', num_classes=None),
            }
        )

        hidden_dims = [8, 16, 32, 64, 128]
        latent_dim = self.latent_size
        self._bottleneck_channels = hidden_dims[-1]
        flat_dim = self._bottleneck_channels * self._BOTTLENECK_LENGTH

        # Encoder: each block halves the temporal length.
        encoder_blocks = []
        in_channels = 1
        for h_dim in hidden_dims:
            encoder_blocks.append(
                nn.Sequential(
                    nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=h_dim,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                    ),
                    nn.BatchNorm1d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*encoder_blocks)
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_var = nn.Linear(flat_dim, latent_dim)

        # Decoder: mirror of the encoder, each block doubles the length.
        self.decoder_input = nn.Linear(latent_dim, flat_dim)

        decoder_dims = hidden_dims[::-1]
        decoder_blocks = []
        for in_dim, out_dim in zip(decoder_dims, decoder_dims[1:]):
            decoder_blocks.append(
                nn.Sequential(
                    nn.ConvTranspose1d(
                        in_dim,
                        out_dim,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=1,
                    ),
                    nn.BatchNorm1d(out_dim),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*decoder_blocks)

        last_dim = decoder_dims[-1]
        self.final_layer = nn.Sequential(
            nn.ConvTranspose1d(
                last_dim,
                last_dim,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            nn.BatchNorm1d(last_dim),
            nn.LeakyReLU(),
            # Un-padded kernel of 6 trims the up-sampled signal by 5 samples.
            nn.Conv1d(last_dim, out_channels=1, kernel_size=6, padding=0),
            nn.Tanh(),
        )

    def reparametize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from a standard normal."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        A few scalar probes (single weights and the first latent statistics of
        the first sample) are logged through ``self.log`` on every call.

        Args:
            x: 3-D input tensor; the encoder's first convolution expects one
                channel on dimension 1.

        Returns:
            Tuple ``(mean, logvar, x_hat)``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a 3-D input (batch, channels, length), got shape {tuple(x.shape)}"
            )

        # Encode input space to hidden space.
        enc = self.encoder(x)
        self.log("enc_weight", self.encoder[4][0].weight[0][0][0], **self._LOG_OPTIONS)
        enc = torch.flatten(enc, start_dim=1)

        # Hidden space to latent distribution parameters.
        mean = self.fc_mu(enc)
        logvar = self.fc_var(enc)
        self.log("mean_sample_0", mean.data[0][0], **self._LOG_OPTIONS)
        self.log("logvar_sample_0", logvar.data[0][0], **self._LOG_OPTIONS)

        z = self.reparametize(mean, logvar)  # batch_size x latent_size

        # Decode latent space back to input space.
        z = self.decoder_input(z)
        self.log("dec_input_weight", self.decoder_input.weight[0][0], **self._LOG_OPTIONS)

        z = z.view(-1, self._bottleneck_channels, self._BOTTLENECK_LENGTH)
        out = self.decoder(z)
        self.log("dec_weight", self.decoder[0][0].weight[0][0][0], **self._LOG_OPTIONS)

        x_hat = self.final_layer(out)
        return mean, logvar, x_hat

    def loss_function(self, *args, **kwargs) -> dict:
        r"""Compute the VAE loss (MSE reconstruction + weighted KL divergence).

        KL(N(\mu, \sigma), N(0, 1)) =
            \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}

        Positional arguments, in order: ``recons``, ``input``, ``mu``,
        ``log_var`` and ``test``. ``test`` may be omitted or passed as a
        keyword; it defaults to ``False``.

        When ``test`` is true the reconstruction and KL terms are returned per
        sample (used as anomaly scores); otherwise the reconstruction error is
        summed over the batch and the KL term is averaged over the batch.

        Returns:
            Dict with the keys "loss", "Reconstruction_Loss" and "KLD".
        """
        recons = args[0]
        target = args[1]
        mu = args[2]
        log_var = args[3]
        test = args[4] if len(args) > 4 else kwargs.get("test", False)

        squared_error = F.mse_loss(recons, target, reduction='none')
        kld_per_sample = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1)

        if test:
            recons_loss = torch.squeeze(torch.mean(squared_error, dim=2))
            kld_loss = kld_per_sample
        else:
            recons_loss = torch.sum(torch.squeeze(torch.sum(squared_error, dim=2)))
            kld_loss = torch.mean(kld_per_sample, dim=0)

        # Scale the KL term down for batches smaller than the configured size.
        kld_weight = target.shape[0] / self.batch_size

        loss = recons_loss + (self.beta * kld_weight * kld_loss)

        return {
            "loss": loss,
            "Reconstruction_Loss": recons_loss,
            # NOTE(review): reported with its sign flipped; left unchanged
            # because existing logged metrics rely on this convention.
            "KLD": -kld_loss,
        }

    def configure_optimizers(self):
        """Return Adam plus a ``ReduceLROnPlateau`` scheduler monitoring ``val_loss``."""
        optimizer = Adam(self.parameters(), lr=(self.lr))
        lr_scheduler = ReduceLROnPlateau(
            optimizer,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "val_loss",
        }

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses for one batch.

        Returns:
            The loss dict produced by ``loss_function`` (contains "loss").
        """
        x, _ = batch
        mu, log_var, x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x, mu, log_var, False)
        self.log("train_loss", loss_dict["loss"], **self._LOG_OPTIONS)
        self.log("kld_loss", loss_dict["KLD"], **self._LOG_OPTIONS)
        self.log("recon_loss", loss_dict["Reconstruction_Loss"], **self._LOG_OPTIONS)
        return loss_dict

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses for one batch.

        Returns:
            Dict with the reconstruction, the input and the total loss.
        """
        x, _ = batch
        mu, log_var, x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x, mu, log_var, False)
        self.log("val_loss", loss_dict["loss"], **self._LOG_OPTIONS)
        self.log("val_kld_loss", loss_dict["KLD"], **self._LOG_OPTIONS)
        self.log("val_recon_loss", loss_dict["Reconstruction_Loss"], **self._LOG_OPTIONS)
        return {"reconstructed": x_out, "original": x, "loss": loss_dict["loss"]}

    def validation_epoch_end(self, outputs):
        """Save one example reconstruction per epoch and log it as an artifact.

        The matching original signal is saved only during epoch 0. Nothing is
        done when ``self.save_images`` is falsy.
        """
        if not self.save_images:
            return
        os.makedirs(self.saved_images_path, exist_ok=True)

        recon = torch.cat([tmp["reconstructed"] for tmp in outputs])
        orig = torch.cat([tmp["original"] for tmp in outputs])

        # Fall back to the last sample when the validation set is too small
        # to contain the preferred preview index.
        preview_idx = min(self._PREVIEW_SAMPLE_INDEX, recon.shape[0] - 1)

        if self.current_epoch == 0:
            orig_saved_filename = self.save_signal_figure(orig[preview_idx], "original")
            self.logger.experiment.log_artifact(self.logger.run_id, orig_saved_filename)
        recon_saved_filename = self.save_signal_figure(recon[preview_idx], "reconstructed")
        self.logger.experiment.log_artifact(self.logger.run_id, recon_saved_filename)

    def save_signal_figure(self, signal_tensor, name):
        """Plot the first row of ``signal_tensor`` and write it as a PNG.

        Returns:
            Path of the written image inside ``self.saved_images_path``.
        """
        ecg = signal_tensor[0].cpu().detach().numpy()
        fig = visualize_ecg(ecg)
        saved_filename = (
            f"{self.saved_images_path}/{name}_epoch_{self.current_epoch+1}.png"
        )
        fig.write_image(saved_filename)
        return saved_filename

    def test_step(self, batch, batch_idx):
        """Run one test batch and return per-sample losses alongside the labels."""
        x, labels = batch
        mu, log_var, x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x, mu, log_var, True)
        return {
            "reconstructed": x_out,
            "original": x,
            "labels": labels,
            "loss": loss_dict["loss"],
            "recon_loss": loss_dict["Reconstruction_Loss"],
            "kld_loss": loss_dict["KLD"],
        }

    def test_epoch_end(self, outputs):
        """Score the test set by reconstruction error and log AUROC and the ROC plot.

        The reconstruction errors are min-max normalised to [0, 1] before being
        passed to the metrics. Nothing is done when ``self.save_images`` is
        falsy.
        """
        if not self.save_images:
            return
        os.makedirs(self.saved_images_path, exist_ok=True)

        labels = torch.cat([tmp["labels"] for tmp in outputs]).type(torch.int64)
        # reshape(-1): a batch holding a single sample yields a 0-d loss,
        # which torch.cat cannot concatenate.
        recon_loss = torch.cat([tmp["recon_loss"].reshape(-1) for tmp in outputs])

        lowest = recon_loss.min()
        spread = recon_loss.max() - lowest
        if spread == 0:
            # All scores identical: avoid a 0/0 division producing NaNs.
            recon_loss = torch.zeros_like(recon_loss)
        else:
            recon_loss = (recon_loss - lowest) / spread

        # Update each metric exactly once with the full test set.
        auc_value = self.metrics["test_auc"](recon_loss, labels)
        self.metrics["test_roc"].update(recon_loss, labels)
        fpr, tpr, thresholds = self.metrics["test_roc"].compute()

        self.log("test_auc", auc_value, prog_bar=True)

        plot_filepath = Path("test_roc.png")
        saved_filename = create_and_save_roc_curve(fpr, tpr, thresholds, plot_filepath)
        self.logger.experiment.log_artifact(self.logger.run_id, saved_filename)
        plot_filepath.unlink()
        
    

In [28]:
def train_model_with_hyperparams(
    config,
    num_epochs=20,
    data_dir=Path("/home/harshit/ecg-anomaly-detection/ecg/"),
):
    """Train and then test an ``ECG_CNN_VAE`` with the given hyperparameters.

    Any previous contents of ``../models/`` are deleted; the checkpoint with
    the lowest ``val_loss`` is stored there as ``vae``. Runs are logged to the
    MLflow experiment ``ecg_vae`` under ``../mlruns``.

    Args:
        config: Hyperparameter mapping passed to ``ECG_CNN_VAE``; its
            ``batch_size`` entry is also used for the data module.
        num_epochs: Maximum number of training epochs.
        data_dir: Folder handed to ``ECGDataModule``.
    """
    # Clear out the previous models. The folder is absent on a first run, so
    # only remove it when it is actually there.
    models_folder = Path("../models/")
    if models_folder.exists():
        shutil.rmtree(models_folder)
    models_folder.mkdir(exist_ok=True, parents=True)

    mlflow_runs_folder = Path("../mlruns")
    mlflow_runs_folder.mkdir(exist_ok=True)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor="val_loss",
        dirpath=models_folder,
        filename="vae",
        save_top_k=1,
        mode="min",
    )
    # Early stopping is intentionally not enabled: training always runs for
    # ``num_epochs`` epochs.
    callbacks = [checkpoint_callback]

    mlf_logger = MLFlowLogger(
        experiment_name="ecg_vae",
        tracking_uri=mlflow_runs_folder.absolute().as_uri(),
    )
    ecg_datamodule = ECGDataModule(
        data_dir=data_dir,
        split_seed=SEED_VALUE,
        num_workers=16,
        batch_size=config["batch_size"],
    )
    ecg_datamodule.setup()
    ecg_model = ECG_CNN_VAE(config)
    trainer = pl.Trainer(
        logger=mlf_logger,
        max_epochs=num_epochs,
        callbacks=callbacks,
        num_sanity_val_steps=0,
        accelerator="gpu",
        devices=1,
    )
    trainer.fit(model=ecg_model, datamodule=ecg_datamodule)
    trainer.test(model=ecg_model, datamodule=ecg_datamodule)
    
    

In [29]:
# Model Summary
# Build the VAE with a 4-dimensional latent space and print its layer summary
# for a batch of 32 single-channel signals of 187 samples.
config = dict(
    beta=0.1,
    lr=1e-4,
    save_images=True,
    saved_images_path="plots",
    input_size=1,
    hidden_size=4,
    latent_size=4,
    num_layers=1,
    batch_size=32,
)
model = ECG_CNN_VAE(config)
summary(model, input_size=(32, 1, 187))


You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`



Layer (type:depth-idx)                   Output Shape              Param #
ECG_CNN_VAE                              [32, 4]                   --
├─Sequential: 1-1                        [32, 128, 6]              --
│    └─Sequential: 2-1                   [32, 8, 94]               --
│    │    └─Conv1d: 3-1                  [32, 8, 94]               32
│    │    └─BatchNorm1d: 3-2             [32, 8, 94]               16
│    │    └─LeakyReLU: 3-3               [32, 8, 94]               --
│    └─Sequential: 2-2                   [32, 16, 47]              --
│    │    └─Conv1d: 3-4                  [32, 16, 47]              400
│    │    └─BatchNorm1d: 3-5             [32, 16, 47]              32
│    │    └─LeakyReLU: 3-6               [32, 16, 47]              --
│    └─Sequential: 2-3                   [32, 32, 24]              --
│    │    └─Conv1d: 3-7                  [32, 32, 24]              1,568
│    │    └─BatchNorm1d: 3-8             [32, 32, 24]              64
│    │    └

In [39]:
# Train the model for 100 epochs with an 8-dimensional latent space.
config = dict(
    beta=0.1,
    lr=1e-3,
    save_images=True,
    saved_images_path="plots",
    input_size=1,
    hidden_size=8,
    latent_size=8,
    num_layers=1,
    batch_size=32,
)
train_model_with_hyperparams(config=config, num_epochs=100)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type       | Params
---------------------------------------------
0 | metrics       | ModuleDict | 0     
1 | encoder       | Sequential | 33.4 K
2 | fc_mu         | Linear     | 6.2 K 
3 | fc_var        | Linear     | 6.2 K 
4 | decoder_input | Linear     | 6.9 K 
5 | decoder       | Sequential | 33.0 K
6 | final_layer   | Sequential | 265   
---------------------------------------------
85.9 K    Trainable params
0         Non-trainable params
85.9 K    Total params
0.344     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Experiment with name ecg_vae not found. Creating it.


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Testing: 0it [00:00, ?it/s]

<Figure size 640x480 with 0 Axes>

In [30]:
# See the accuracy of the model

def test_latest_model(config, num_epochs):
    """Run the test loop on the checkpoint stored at ``../models/vae.ckpt``.

    Args:
        config: Hyperparameter mapping used to build ``ECG_CNN_VAE``; its
            ``batch_size`` entry is also used for the data module.
        num_epochs: Forwarded to the trainer as ``max_epochs``.
    """
    runs_dir = Path("../mlruns")
    runs_dir.mkdir(exist_ok=True)

    logger = MLFlowLogger(
        experiment_name="ecg_vae",
        tracking_uri=runs_dir.absolute().as_uri(),
    )

    datamodule = ECGDataModule(
        data_dir=Path("/home/harshit/ecg-anomaly-detection/ecg/"),
        split_seed=SEED_VALUE,
        num_workers=16,
        batch_size=config["batch_size"],
    )
    datamodule.setup()

    vae = ECG_CNN_VAE(config)

    trainer_options = {
        "logger": logger,
        "max_epochs": num_epochs,
        "num_sanity_val_steps": 0,
        "accelerator": "gpu",
        "devices": 1,
    }
    trainer = pl.Trainer(**trainer_options)
    trainer.test(model=vae, datamodule=datamodule, ckpt_path="../models/vae.ckpt")
    
# Evaluate the stored checkpoint with the same hyperparameters used for training.
config = dict(
    beta=0.1,
    lr=1e-3,
    save_images=True,
    saved_images_path="plots",
    input_size=1,
    hidden_size=8,
    latent_size=8,
    num_layers=1,
    batch_size=32,
)
test_latest_model(config, num_epochs=1)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at ../models/vae.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from checkpoint at ../models/vae.ckpt


Testing: 0it [00:00, ?it/s]

<Figure size 640x480 with 0 Axes>

In [32]:
# Predict on a series of models
from typing import List

# Hyperparameters matching the stored checkpoint.
config = dict(
    beta=0.1,
    lr=1e-3,
    save_images=True,
    saved_images_path="plots",
    input_size=1,
    hidden_size=8,
    latent_size=8,
    num_layers=1,
    batch_size=32,
)

ecg_model = ECG_CNN_VAE(config)
# Restore the trained weights and switch the model to evaluation mode.
model = ECG_CNN_VAE.load_from_checkpoint("../models/vae.ckpt", config=config)
model.eval()


def visualize_long_sample(samples: List, predictions: List, means: List, sigmas: List, show_points: bool = False):
    """Plot concatenated ECG segments with predictions, labels and latent statistics.

    The segments in ``samples`` are laid end to end on a shared x axis. Row 1
    shows the ECG, row 2 the per-segment prediction, row 3 the per-segment
    label, and one further row per latent dimension shows its mean together
    with ``mean + |sigma|`` and ``mean - |sigma|``.

    Args:
        samples: Sequence of ``(ecg, label)`` pairs. Each ``ecg`` is iterated
            two levels deep to obtain its points, and ``ecg.data[0]`` gives
            the segment length.
        predictions: One scalar-like score per segment.
        means: One array of latent means per segment.
        sigmas: One array of latent standard deviations per segment.
        show_points: Draw markers in addition to lines.

    Returns:
        The plotly figure.
    """
    first_ecg, _ = samples[0]
    segment_length = len(first_ecg.data[0])

    def _stretch(value):
        # Repeat a per-segment value over every point of the segment.
        return np.ones(segment_length) * value

    def _concat(chunks):
        # Flatten one nesting level into a plain list.
        return [point for chunk in chunks for point in chunk]

    ecgs = _concat(_concat(ecg for ecg, _ in samples))
    labels = _concat(_stretch(label) for _, label in samples)
    preds = _concat(_stretch(pred) for pred in predictions)

    # atleast_1d keeps a single latent dimension (0-d array) indexable.
    means = [np.atleast_1d(sample_means) for sample_means in means]
    sigmas = [np.atleast_1d(sample_sigmas) for sample_sigmas in sigmas]
    # One subplot row per latent dimension instead of a fixed count of 8.
    latent_dims = len(means[0]) if means else 0

    fig = make_subplots(rows=3 + latent_dims, cols=1, shared_xaxes=True)

    mode = "lines+markers" if show_points else "lines"
    marker = dict(color="#FFFFFF", size=5, line=dict(color="black", width=1))
    x_axis = [*range(len(labels))]

    def _trace(x, y, name, **line_style):
        return go.Scatter(
            x=x,
            y=y,
            name=name,
            mode=mode,
            marker=marker,
            line=dict(width=2, **line_style),
        )

    fig.append_trace(_trace([*range(len(ecgs))], ecgs, "ECG", color="red"), row=1, col=1)
    fig.append_trace(_trace(x_axis, preds, "Prediction", color="blue"), row=2, col=1)
    fig.append_trace(_trace(x_axis, labels, "Label", color="orange"), row=3, col=1)

    for dim in range(latent_dims):
        dist_mean = _concat(_stretch(sample_means[dim]) for sample_means in means)
        dist_std = _concat(_stretch(sample_sigmas[dim]) for sample_sigmas in sigmas)

        dist_mean_upper = [mean + abs(std) for mean, std in zip(dist_mean, dist_std)]
        dist_mean_lower = [mean - abs(std) for mean, std in zip(dist_mean, dist_std)]

        row = 3 + dim + 1
        base_name = f"Latent_Var_{dim+1}"
        fig.append_trace(_trace(x_axis, dist_mean, base_name, color="black"), row=row, col=1)
        fig.append_trace(
            _trace(x_axis, dist_mean_upper, f"{base_name}_upper", color="firebrick", dash='dash'),
            row=row,
            col=1,
        )
        fig.append_trace(
            _trace(x_axis, dist_mean_lower, f"{base_name}_lower", color="darkgreen", dash='dash'),
            row=row,
            col=1,
        )

    fig = fig.update_layout(width=4000, height=2200)

    return fig
    
    

def generate_long_data_sample(no_of_segments: int = 20, no_of_abnormal_segments: int = 6):
    """Draw a shuffled mix of normal and abnormal samples from the ECG dataset.

    Args:
        no_of_segments: Total number of samples to return.
        no_of_abnormal_segments: How many of them come from
            ``ECGDataset.abnormal_idx``; the rest come from ``normal_idx``.

    Returns:
        List of dataset items in random order.

    Raises:
        ValueError: If more abnormal segments than total segments are requested.
    """
    if no_of_abnormal_segments > no_of_segments:
        raise ValueError("No of abnormal segments must not exceed the total no of segments.")

    no_of_normal_segments = no_of_segments - no_of_abnormal_segments

    full_dataset = ECGDataset(data_dir=Path("/home/harshit/ecg-anomaly-detection/ecg"))

    # sorted() gives random.sample a sequence to draw from.
    normal_indexes = random.sample(sorted(full_dataset.normal_idx), no_of_normal_segments)
    abnormal_indexes = random.sample(sorted(full_dataset.abnormal_idx), no_of_abnormal_segments)

    samples = [full_dataset[index] for index in normal_indexes]
    samples += [full_dataset[index] for index in abnormal_indexes]

    random.shuffle(samples)

    return samples

def get_prediction_for_ecg(ecg):
    """Score a single ECG segment with the module-level ``model``.

    Args:
        ecg: Tensor with 187 elements; it is viewed as shape ``(1, 1, 187)``.

    Returns:
        Tuple ``(reconstruction_loss, mean, sigma)`` where ``sigma`` is
        ``exp(0.5 * logvar)``.
    """
    ecg = ecg.view(1, 1, 187)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        mean, logvar, x_hat = model(ecg)
        loss_dict = model.loss_function(x_hat, ecg, mean, logvar, True)
        # exp() is always non-negative, so no abs() is needed.
        sigma = torch.exp(0.5 * logvar)

    return loss_dict["Reconstruction_Loss"], mean, sigma
    
    
samples = generate_long_data_sample()
outputs = [get_prediction_for_ecg(sample) for sample, _ in samples]

# Each output is a (reconstruction_loss, mean, sigma) triple.
predictions = [recon_loss.detach().cpu().numpy() for recon_loss, _, _ in outputs]
means = [torch.squeeze(mean).detach().cpu().numpy() for _, mean, _ in outputs]
# sigma is the THIRD element; unpacking it from the second slot plots the means twice.
sigmas = [torch.squeeze(sigma).detach().cpu().numpy() for _, _, sigma in outputs]

visualize_long_sample(samples=samples, predictions=predictions, means=means, sigmas=sigmas)



You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`

