# Training VAE CNN
The goal is to train a Variational Autoencoder (VAE) built from convolutional layers on the valid sequences in the dataset.

### Dataset Checkpoint Input
- extracted-segments

## Input Parameters
* WINDOW_LENGTH: The length of the sliding window moved over the valid input segments to extract training data.
* STRIDE: The stride of the sliding window.
* PH_THRESHOLD: The threshold below which a CTG is considered to be associated with an abnormal birth.

In [None]:
# Imports
import pytorch_lightning as pl
import pandas as pd
from pathlib import Path
from torch import Tensor, nn
from pandarallel import pandarallel
from typing import Optional
from torch.utils.data import (
    Dataset,
    DataLoader,
    Subset,
)
import torch
import os
import numpy as np
from sklearn.model_selection import train_test_split
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import shutil
from torchinfo import summary
import torch.nn.functional as F
from datetime import timedelta, datetime
from tqdm import tqdm
from typing import List
import random
import math
from collections import OrderedDict

In [None]:
# Initial setup for reproducibility
# The same seed is reused further down as the train/val/test split seed.
SEED_VALUE = 500
pl.seed_everything(SEED_VALUE)

# Constants
# Arterial pH threshold below which a CTG is considered to be associated with
# an abnormal birth.
PH_THRESHOLD = 7.05
# 270-day span. NOTE(review): presumably the longest CTG-to-birth gap that is
# considered -- confirm.
MAX_DATETIME_DIFFERENCE_CTG = timedelta(days=270)

# Initialize Parallel Apply
pandarallel.initialize(progress_bar=True)
# Registers the tqdm progress-bar integration for pandas.
tqdm.pandas()

# Visualization function for CTG
import sys

# `ctg_common` is imported from the parent directory, so that directory has to
# be on the module search path before the import below runs.
sys.path.append("../")
from ctg_common import visualize_ctg

In [None]:
# Root folder of the dataset checkpoint. `data_dir` is also read as a global by
# CTGVAE.validation_step to derive the folder the plots are written to.
data_dir = Path("../ctg-data/final-ctg-dataset")
# One row per extracted segment; the same table is loaded again by CTGDataset.
expanded_outcome = pd.read_parquet(data_dir / "outcome_expanded.parquet")
# Notebook display: preview the first 15 rows.
expanded_outcome.head(15)

In [None]:
class CTGDataset(Dataset):
    """Dataset of fixed-length FHR segments cut from the filtered CTG segments.

    Only outcome rows with more than ``MIN_POINTS`` and at most
    ``SEGMENT_LENGTH`` points and an ``ns_art_ph`` of at least
    ``PH_THRESHOLD`` are kept, so every remaining sample carries label 0.
    """

    # Number of samples every returned segment must have.
    SEGMENT_LENGTH = 2400
    # Outcome rows with this many points or fewer are discarded.
    MIN_POINTS = 2200
    # Sampling period (0.25 s) used when a short window is resampled.  A
    # Timedelta is used instead of the "0.25S" alias, whose upper-case "S" is
    # deprecated in recent pandas releases.
    RESAMPLE_PERIOD = pd.Timedelta(milliseconds=250)

    def __init__(
        self,
        data_dir: Path,
    ):
        """Load the outcome table from ``data_dir`` and keep the usable rows.

        Args:
            data_dir: Folder containing ``outcome_expanded.parquet`` and the
                ``ctgs/filtered_segments`` parquet files.
        """
        self.data_dir = data_dir
        outcomes = pd.read_parquet(data_dir / "outcome_expanded.parquet")

        n_points = outcomes["no_of_points"]
        has_valid_length = (n_points > self.MIN_POINTS) & (
            n_points <= self.SEGMENT_LENGTH
        )
        has_normal_ph = outcomes["ns_art_ph"] >= PH_THRESHOLD

        # Copy the filtered rows so the label column is written to an
        # independent frame rather than to a slice of the loaded table.
        outcomes = outcomes[has_valid_length & has_normal_ph].copy()
        outcomes["label"] = (outcomes["ns_art_ph"] < PH_THRESHOLD).astype(int)

        self.outcomes = outcomes
        self._targets = torch.tensor(self.outcomes["label"].values)

    @classmethod
    def _pad_and_resample(cls, seq_ctg, seq_start, seq_end):
        """Stretch a short window to ``[seq_start, seq_end]`` on a regular grid.

        Empty rows are added at the requested start/end when the data does not
        reach them, the frame is resampled to ``RESAMPLE_PERIOD`` and the gaps
        in ``FHR1`` and ``TOCO`` are filled by linear interpolation.
        """
        first_time = seq_ctg.index.min()
        last_time = seq_ctg.index.max()

        pieces = [seq_ctg]
        if seq_start < first_time:
            pieces.insert(
                0, pd.DataFrame(np.nan, index=[seq_start], columns=seq_ctg.columns)
            )
        if seq_end > last_time:
            # The window itself is added only once; duplicating it would just
            # repeat every sample before the per-bin mean.
            pieces.append(
                pd.DataFrame(np.nan, index=[seq_end], columns=seq_ctg.columns)
            )
        if len(pieces) > 1:
            seq_ctg = pd.concat(pieces)

        resampled = seq_ctg.resample(cls.RESAMPLE_PERIOD).agg(
            {"FHR1": "mean", "TOCO": "mean"}
        )
        for column in ("FHR1", "TOCO"):
            # limit_direction="both" also fills the gap between a padded start
            # row and the first real sample; the forward-only default would
            # leave NaNs at the beginning of the signal.
            resampled[column] = resampled[column].interpolate(
                method="linear", limit_direction="both"
            )
        return resampled

    def load_relevant_segments(self, row):
        """Return the FHR signal described by one outcome row.

        Args:
            row: Outcome row providing ``identifier``, ``segment_number``,
                ``start`` and ``end``.

        Returns:
            Float tensor of shape ``(1, SEGMENT_LENGTH)`` holding ``FHR1``.

        Raises:
            ValueError: If the prepared segment does not have exactly
                ``SEGMENT_LENGTH`` samples.
        """
        identifier: str = row["identifier"]
        segment_number = row["segment_number"]
        seq_start = row["start"]
        seq_end = row["end"]

        ctg_segment = pd.read_parquet(
            self.data_dir
            / "ctgs"
            / "filtered_segments"
            / f"{identifier}_{segment_number}.parquet"
        )

        # Select the window; truncate when too long, pad/resample when too short.
        seq_ctg = ctg_segment[seq_start:seq_end]
        if len(seq_ctg) > self.SEGMENT_LENGTH:
            seq_ctg = seq_ctg.head(self.SEGMENT_LENGTH)
        elif len(seq_ctg) < self.SEGMENT_LENGTH:
            seq_ctg = self._pad_and_resample(seq_ctg, seq_start, seq_end)

        fhr_values = seq_ctg["FHR1"].to_numpy(copy=True)
        if len(fhr_values) != self.SEGMENT_LENGTH:
            raise ValueError(
                f"FHR segment has length {len(fhr_values)}, "
                f"expected length {self.SEGMENT_LENGTH}."
            )

        # Add the channel axis expected by the 1-D convolutions.
        return torch.tensor(fhr_values, dtype=torch.float32).unsqueeze(0)

    def __len__(self):
        """Return the number of usable outcome rows."""
        return self.outcomes.shape[0]

    def __getitem__(self, index):
        """Return ``(signal, label)`` for the outcome row at position ``index``."""
        outcome = self.outcomes.iloc[index]
        ctg_tensor = self.load_relevant_segments(outcome)
        label: int = int(outcome["label"])
        return ctg_tensor, label

    @property
    def targets(self) -> Tensor:
        """Labels of all samples, in dataset order."""
        return self._targets

In [None]:
class CTGDataModule(pl.LightningDataModule):
    """Datamodule that splits the CTG dataset and serves reduced-size subsets.

    The full dataset is split 90/10 into train+val and test, and the train+val
    part 75/25 into train and val.  The dataloaders serve randomly drawn
    subsets of these splits ("smaller dataset experiment").
    """

    # Upper bounds, in batches, on the size of the sampled subsets.
    SMALL_TRAIN_BATCHES = 160
    SMALL_EVAL_BATCHES = 10

    def __init__(
        self,
        data_dir: Path = Path("../ctg-data/final-ctg-dataset/"),
        batch_size: int = 32,
        split_seed: int = 50,
        num_workers: int = 1,
    ):
        """Store the configuration; the data itself is loaded in ``setup``.

        Args:
            data_dir: Folder of the dataset checkpoint.
            batch_size: Batch size of every dataloader.
            split_seed: Seed for the splits and for the subset sampling.
            num_workers: Worker processes per dataloader.
        """
        super().__init__()

        self.save_hyperparameters(logger=False)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

    @staticmethod
    def _sample_indexes(rng, indexes, max_count):
        """Draw up to ``max_count`` distinct entries of ``indexes`` using ``rng``.

        The count is capped at the number of available indexes so that a small
        split does not make ``sample`` raise ``ValueError``.
        """
        pool = sorted(int(i) for i in indexes)
        return rng.sample(pool, min(max_count, len(pool)))

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the dataset, the stratified splits and the sampled subsets."""
        seed = self.hparams.split_seed
        batch_size = self.hparams.batch_size

        self.full_dataset = CTGDataset(
            self.hparams.data_dir,
        )
        train_idx, test_idx = train_test_split(
            np.arange(len(self.full_dataset)),
            test_size=0.1,
            random_state=seed,
            shuffle=True,
            stratify=self.full_dataset.targets,
        )
        self.train_dataset = Subset(self.full_dataset, train_idx)
        self.data_test = Subset(self.full_dataset, test_idx)
        self.train_targets = self.full_dataset.targets[train_idx]
        self.test_targets = self.full_dataset.targets[test_idx]

        # These indexes are positions inside `self.train_dataset`, not inside
        # the full dataset.
        train_indexes, val_indexes = train_test_split(
            np.arange(len(self.train_dataset)),
            test_size=0.25,
            random_state=seed,
            shuffle=True,
            stratify=self.train_targets,
        )

        # Smaller dataset experiment.  A dedicated generator seeded with
        # `split_seed` keeps the subsets identical across repeated `setup`
        # calls (manual call, `fit`, `test`) instead of depending on the state
        # of the global `random` module.
        rng = random.Random(seed)
        self.data_train_small = Subset(
            self.train_dataset,
            self._sample_indexes(
                rng, train_indexes, batch_size * self.SMALL_TRAIN_BATCHES
            ),
        )
        self.data_val_small = Subset(
            self.train_dataset,
            self._sample_indexes(
                rng, val_indexes, batch_size * self.SMALL_EVAL_BATCHES
            ),
        )
        self.data_test_small = Subset(
            self.full_dataset,
            self._sample_indexes(rng, test_idx, batch_size * self.SMALL_EVAL_BATCHES),
        )

        self.data_train = Subset(self.train_dataset, train_indexes)
        self.data_val = Subset(self.train_dataset, val_indexes)

        # TODO: Apply normalization to data
        #   - Determine scaler based on training data
        #   - Apply scaling based on training data to val
        #   - Apply scaling based on training data to test

    def train_dataloader(self):
        """Shuffled loader over the sampled training subset."""
        return DataLoader(
            self.data_train_small,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=self.hparams.num_workers,
        )

    def val_dataloader(self):
        """Loader over the sampled validation subset."""
        return DataLoader(
            self.data_val_small,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
        )

    def test_dataloader(self):
        """Loader over the sampled test subset."""
        return DataLoader(
            self.data_test_small,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
        )

In [None]:
# Define the model
class CTGVAE(pl.LightningModule):
    """Convolutional variational autoencoder that reconstructs CTG signals.

    Hyperparameters are passed as keyword arguments and stored in
    ``self.hparams``.  The model reads ``no_of_features``, ``latent_size``,
    ``batch_size``, ``lr``, ``weight_decay`` and ``kl_annealing_epoch``.
    """

    # Temporal length of the encoder output: the five stride-4 convolutions
    # reduce a 2400-sample input to 2 steps (2400 -> 600 -> 150 -> 37 -> 9 -> 2).
    BOTTLENECK_LENGTH = 2

    def __init__(self, **hparams):
        """Build the encoder, the latent projections and the decoder."""
        super().__init__()
        self.save_hyperparameters()

        hidden_dims = [32, 64, 128, 256, 512]
        latent_dim = self.hparams.latent_size
        # Remembered so `forward` can restore the encoder output shape.
        self.bottleneck_channels = hidden_dims[-1]
        flat_features = self.bottleneck_channels * self.BOTTLENECK_LENGTH

        # Build Encoder
        encoder_blocks = []
        in_channels = self.hparams.no_of_features
        for h_dim in hidden_dims:
            encoder_blocks.append(
                nn.Sequential(
                    nn.Conv1d(
                        in_channels=in_channels,
                        out_channels=h_dim,
                        kernel_size=6,
                        stride=4,
                        padding=1,
                    ),
                    nn.BatchNorm1d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*encoder_blocks)
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Build Decoder
        self.decoder_input = nn.Linear(latent_dim, flat_features)

        hidden_dims.reverse()

        decoder_blocks = []
        for in_dim, out_dim in zip(hidden_dims[:-1], hidden_dims[1:]):
            decoder_blocks.append(
                nn.Sequential(
                    nn.ConvTranspose1d(
                        in_dim,
                        out_dim,
                        kernel_size=6,
                        stride=4,
                        padding=1,
                        output_padding=1,
                    ),
                    nn.BatchNorm1d(out_dim),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*decoder_blocks)

        # NOTE(review): the Tanh limits the output to [-1, 1], which only
        # matches the input if the signals are scaled to that range -- confirm.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose1d(
                hidden_dims[-1],
                hidden_dims[-1],
                kernel_size=16,
                stride=4,
                padding=1,
                output_padding=2,
            ),
            nn.BatchNorm1d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv1d(
                hidden_dims[-1],
                out_channels=self.hparams.no_of_features,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.Tanh(),
        )

    def reparametize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` while training; return ``mu`` otherwise."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward(self, x):
        """Encode ``x``, draw a latent vector and decode it.

        Returns:
            Tuple ``(mean, logvar, x_hat)``.
        """
        # encode input space to hidden space
        enc = self.encoder(x)
        enc = torch.flatten(enc, start_dim=1)

        # extract latent variable z (hidden space to latent space)
        mean = self.fc_mu(enc)
        logvar = self.fc_var(enc)

        z = self.reparametize(mean, logvar)  # batch_size x latent_size

        # decode latent space to input space
        z = self.decoder_input(z)
        z = z.view(-1, self.bottleneck_channels, self.BOTTLENECK_LENGTH)
        out = self.decoder(z)
        x_hat = self.final_layer(out)

        return mean, logvar, x_hat

    def loss_function(self, x_hat, x, mu, log_var) -> dict:
        """Compute reconstruction loss, KL divergence and their weighted sum.

        Per-sample terms are summed over the batch while training and averaged
        otherwise.

        Returns:
            Dict with ``loss``, ``recon_loss`` and ``kld`` (the negated KL term).
        """
        reduce_batch = torch.sum if self.training else torch.mean

        recon_per_sample = torch.sum(
            F.mse_loss(x_hat, x, reduction="none"), dim=(1, 2)
        )
        recons_loss = reduce_batch(recon_per_sample)

        kld_per_sample = -0.5 * torch.sum(
            1 + log_var - mu**2 - log_var.exp(), dim=1
        )
        kld_loss = reduce_batch(kld_per_sample, dim=0)

        # To account for minibatches
        batch_weight = x.shape[0] / self.hparams.batch_size
        kld_weight = self.get_cyclic_kl_annealing_weight()

        loss = (recons_loss + (kld_weight * kld_loss)) * batch_weight

        return {
            "loss": loss,
            "recon_loss": recons_loss,
            "kld": -kld_loss,
        }

    def configure_optimizers(self):
        """Return AdamW with a cosine-annealing warm-restart schedule."""
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
            eps=1e-4,
        )
        sch = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            opt, T_0=25, T_mult=1, eta_min=1e-9, last_epoch=-1
        )
        return [opt], [sch]

    def training_step(self, batch, batch_idx):
        """Run one training batch and log its losses under the ``train`` prefix."""
        x, _ = batch
        mu, log_var, x_hat = self.forward(x)
        # The loss is computed once; a second identical call only cost time.
        loss_dict = self.loss_function(x_hat, x, mu, log_var)
        self.calculate_and_log_metrics(loss_dict, "train", self.hparams.batch_size)
        return loss_dict

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; for the first batch also save example plots."""
        x, _ = batch
        mu, log_var, x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x, mu, log_var)
        self.calculate_and_log_metrics(loss_dict, "val", self.hparams.batch_size)
        if batch_idx == 0:
            # `data_dir` is the notebook-level global defined next to the
            # outcome preview.
            self.plots_path = data_dir / ".." / ".." / "plots"
            self.plots_path.mkdir(exist_ok=True)
            # The original signal is saved once: in the first epoch, or
            # whenever its file is missing.
            if (self.current_epoch == 0) or (
                not (self.plots_path / "original_fhr.png").exists()
            ):
                orig_saved_filename = self.save_signal_figure(x[0], "original_fhr", 0)
                self.logger.experiment.log_artifact(
                    self.logger.run_id, orig_saved_filename
                )
            recon_saved_filename = self.save_signal_figure(
                x_out[0], "reconstructed_fhr", 0
            )
            self.logger.experiment.log_artifact(
                self.logger.run_id, recon_saved_filename
            )
        return loss_dict

    def test_step(self, batch, batch_idx):
        """Run one test batch and log its losses under the ``test`` prefix."""
        x, _ = batch
        mu, log_var, x_out = self.forward(x)
        loss_dict = self.loss_function(x_out, x, mu, log_var)
        # Logged as "test_*" so test results do not overwrite the "val_*" metrics.
        self.calculate_and_log_metrics(loss_dict, "test", self.hparams.batch_size)
        return loss_dict

    def calculate_and_log_metrics(self, loss_dict, step_type, batch_size):
        """Log the KL weight and the three loss terms of ``loss_dict``.

        Args:
            loss_dict: Output of ``loss_function``.
            step_type: Prefix of the metric names, e.g. ``"train"``.
            batch_size: Batch size handed to the logger.
        """
        shared = dict(
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
            batch_size=batch_size,
        )
        self.log(
            "kld_weight",
            self.get_cyclic_kl_annealing_weight(),
            on_step=False,
            **shared,
        )
        for metric_name, key in (
            ("loss", "loss"),
            ("kld_loss", "kld"),
            ("recon_loss", "recon_loss"),
        ):
            self.log(
                f"{step_type}_{metric_name}",
                loss_dict[key],
                on_step=True,
                **shared,
            )

    def save_signal_figure(self, signal_tensor, name, signal_dim):
        """Plot one channel of ``signal_tensor`` and write it as a PNG.

        Names containing ``"original"`` are saved as ``<name>.png``; all other
        names get the 1-based epoch number appended.

        Returns:
            Path of the written image inside ``self.plots_path``.
        """
        fhr_tensor = signal_tensor[signal_dim].cpu().detach().numpy()
        output_df = pd.DataFrame.from_dict({"FHR1": fhr_tensor})
        fig = visualize_ctg(ctg=output_df, toco_present=False)

        if "original" in name:
            saved_filename = f"{name}.png"
        else:
            saved_filename = f"{name}_epoch_{self.current_epoch+1}.png"

        saved_filepath = self.plots_path / saved_filename
        fig.write_image(saved_filepath)
        return saved_filepath

    def get_monotonic_kl_annealing_weight(self):
        """Return a KL weight rising linearly to 1 over ``kl_annealing_epoch`` epochs."""
        epoch_number = self.current_epoch + 1
        if epoch_number > self.hparams.kl_annealing_epoch:
            return 1
        return epoch_number / self.hparams.kl_annealing_epoch

    def get_cyclic_kl_annealing_weight(self):
        """Return the KL weight of the cyclic annealing schedule.

        Epochs are grouped into blocks of ``kl_annealing_epoch``.  In odd
        blocks the weight ramps up linearly and reaches 1 on the block's last
        epoch; in even blocks it stays at 1.
        """
        period = self.hparams.kl_annealing_epoch
        epoch_number = self.current_epoch + 1
        position = epoch_number % period
        in_ramp_block = math.ceil(epoch_number / period) % 2 != 0
        if in_ramp_block and position != 0:
            return position / period
        return 1

In [None]:
def train_model_with_hyperparams(
    hparams: OrderedDict,
    datamodule: pl.LightningDataModule,
    ckpt_path: Optional[str] = None,
):
    """Train and then test a CTGVAE model with the given hyperparameters.

    Checkpoints are written to ``../models`` and the run is tracked with
    MLflow under ``../mlruns``.

    Args:
        hparams: Hyperparameters forwarded to ``CTGVAE``; ``num_epochs`` is
            additionally used as the trainer's epoch limit.
        datamodule: Datamodule providing the train/val/test dataloaders.
        ckpt_path: Optional checkpoint to resume training from.
    """
    models_folder = Path("../models/")

    # Clear out the previous models.  The folder is left alone when the
    # checkpoint to resume from lives inside it, because wiping it would
    # delete that checkpoint before the trainer can load it.
    resumes_from_models_folder = (
        ckpt_path is not None
        and models_folder.resolve() in Path(ckpt_path).resolve().parents
    )
    # The existence check avoids a FileNotFoundError on the very first run.
    if models_folder.exists() and not resumes_from_models_folder:
        shutil.rmtree(models_folder)
    models_folder.mkdir(exist_ok=True, parents=True)

    mlflow_runs_folder = Path("../mlruns")
    mlflow_runs_folder.mkdir(exist_ok=True)

    # Keeps the three checkpoints with the lowest validation loss.
    val_checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor="val_loss",
        dirpath=models_folder,
        filename="models-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        mode="min",
        save_on_train_epoch_end=True,
    )

    # Additionally writes a checkpoint every 10 epochs.
    periodic_checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=models_folder,
        filename="models-{epoch:02d}-{val_loss:.2f}",
        every_n_epochs=10,
    )

    callbacks = [val_checkpoint_callback, periodic_checkpoint_callback]

    mlf_logger = MLFlowLogger(
        experiment_name="ctg_vae_cnn",
        tracking_uri=mlflow_runs_folder.absolute().as_uri(),
    )
    ctg_cnn_model = CTGVAE(**hparams)
    trainer = pl.Trainer(
        logger=mlf_logger,
        max_epochs=hparams["num_epochs"],
        callbacks=callbacks,
        num_sanity_val_steps=0,
        devices=1,
    )
    trainer.fit(model=ctg_cnn_model, datamodule=datamodule, ckpt_path=ckpt_path)
    trainer.test(model=ctg_cnn_model, datamodule=datamodule)

In [None]:
# Model Summary
# Builds a throw-away model only to print its layer summary. `beta` is not
# read anywhere in CTGVAE, and `kl_annealing_epoch` can be left out here
# because it is only needed when a loss is computed.
hparams = OrderedDict(
    beta=0,
    lr=1e-4,
    weight_decay=1e-5,
    no_of_features=1,
    latent_size=24,
    batch_size=256,
)
model = CTGVAE(**hparams)
# input_size is (batch, channels, samples): 128 single-channel signals of
# 2400 samples each.
summary(model, input_size=(128, 1, 2400))

In [None]:
# Build the datamodule and create its splits up front.
# Note that this uses an absolute dataset path, unlike the relative `data_dir`
# used for the outcome preview.
ctg_datamodule = CTGDataModule(
    data_dir=Path("/mnt/ssd2/ctg-analysis/ctg-data/final-ctg-dataset/"),
    split_seed=SEED_VALUE,
    num_workers=16,
    batch_size=256,
)
ctg_datamodule.prepare_data()
ctg_datamodule.setup()

In [None]:
# Try training the model with some hyperparameters.
hparams = OrderedDict(
    lr=1e-3,
    weight_decay=1e-5,
    no_of_features=1,
    latent_size=24,
    # Taken from the datamodule so the model and the dataloaders agree: the
    # loss scales each batch by `x.shape[0] / batch_size`, which is off by a
    # constant factor when the two values differ.
    batch_size=ctg_datamodule.hparams.batch_size,
    kl_annealing_epoch=10,
    num_epochs=101,
)
train_model_with_hyperparams(
    hparams=hparams,
    datamodule=ctg_datamodule,
)