In [None]:
import sys
import importlib

import scipy.io
import scipy.signal as sig
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import DataHandlers.CinCDataset as CinCDataset
import DataHandlers.SAFERDataset as SAFERDataset

# Re-import the project modules so edits to them are picked up without restarting the kernel.
importlib.reload(SAFERDataset)
importlib.reload(CinCDataset)

from DataHandlers.DiagEnum import DiagEnum, feas1DiagToEnum

matplotlib.rcParams["text.usetex"] = False

# A fudge because the modules were moved into the DataHandlers package: register them under their old
# top-level names, presumably so pickles that reference "SAFERDataset"/"CinCDataset" can still be loaded.
# (`sys` was previously never imported, so these two lines raised NameError.)
sys.modules["SAFERDataset"] = SAFERDataset
sys.modules["CinCDataset"] = CinCDataset

### Load the data

In [None]:
feas2_pt_data, feas2_ecg_data = SAFERDataset.load_feas_dataset(2, "dataframe")
# Keep only the recordings whose length is the full 9120 samples.
feas2_ecg_data = feas2_ecg_data.loc[lambda ecgs: ecgs["length"] == 9120]

In [None]:
# Load Feas1, keeping ECGs of every measurement diagnosis except Undecided.
feas1_pt_data, feas1_ecg_data = SAFERDataset.load_feas_dataset(
    1,
    ecg_meas_diag=[diag for diag in DiagEnum if diag != DiagEnum.Undecided],
)

In [None]:
# Load specially cleaned data
from pathlib import Path

# Root of the local SAFER data copy. Edit this one constant to run the notebook on another machine
# (it used to be repeated, as an absolute path, in every read below).
SAFER_DATA_DIR = Path(r"C:\Users\daniel\Documents\2022_23_DSiromani")
# Added to every Feas2 ptID, presumably so Feas2 patients stay distinct from Feas1 ones after concatenation.
FEAS2_PTID_OFFSET = 10000

feas2_ecg_data = pd.read_pickle(SAFER_DATA_DIR / "Feas2" / "ECGs" / "filtered_dataframe.pk")
feas2_pt_data = pd.read_csv(SAFER_DATA_DIR / "Feas2" / "pt_data_anon.csv")
feas2_pt_data["ptID"] += FEAS2_PTID_OFFSET
feas2_ecg_data["ptID"] += FEAS2_PTID_OFFSET

feas1_ecg_data_clean = pd.read_pickle(SAFER_DATA_DIR / "Feas1" / "ECGs" / "clean_ecg_dataset.pk")
feas1_pt_data = pd.read_csv(SAFER_DATA_DIR / "Feas1" / "pt_data_anon.csv")
print(len(feas1_ecg_data_clean.index))

feas2_ecg_data_clean = pd.read_pickle(SAFER_DATA_DIR / "Feas2" / "ECGs" / "clean_ecg_dataset.pk")
feas2_ecg_data_clean["ptID"] += FEAS2_PTID_OFFSET
print(len(feas2_ecg_data_clean.index))

all_clean_data = pd.concat([feas2_ecg_data_clean, feas1_ecg_data_clean], ignore_index=True)
all_clean_pt = pd.concat([
    feas2_pt_data[feas2_pt_data["ptID"].isin(feas2_ecg_data_clean["ptID"])],
    feas1_pt_data[feas1_pt_data["ptID"].isin(feas1_ecg_data_clean["ptID"])],
])

# Index by ptID (keeping the column) so the per-patient recording counts below align on patient ID.
all_clean_pt = all_clean_pt.set_index("ptID", drop=False)
all_clean_pt["noRecs"] = all_clean_data["ptID"].value_counts()
# Every recording in the cleaned datasets counts as high quality.
all_clean_pt["noHQrecs"] = all_clean_pt["noRecs"]
all_clean_pt.head()

### Setup the model

In [None]:
import torch.nn as nn
import torch

In [None]:
class CVAE(nn.Module):
    """1-D convolutional variational autoencoder for single-channel ECG windows.

    ``encode`` maps a (batch, 1, 2048) signal to a (batch, 2 * z_dim) vector. The first z_dim columns
    are latent means; the absolute value of the last z_dim columns is used as the latent standard
    deviation. ``decode`` maps (batch, z_dim) latents back to (batch, 1, 2048). The 2048-sample input
    length is fixed by the 5120 (= 80 channels x 64 steps) linear layers.

    Args:
        z_dim: Latent dimensionality.
    """

    def __init__(self, z_dim):
        super(CVAE, self).__init__()
        self.z_dim = z_dim

        self.conv_section1 = nn.Sequential(
            nn.Conv1d(1, 16, 19, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(16),
            nn.Conv1d(16, 16, 19, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(16)
        )

        self.conv_section2 = nn.Sequential(
            nn.Conv1d(16, 16, 19, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(16),
            nn.Conv1d(16, 16, 19, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(16)
        )

        self.conv_section3 = nn.Sequential(
            nn.Conv1d(16, 32, 9, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Conv1d(32, 32, 9, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(32)
        )

        self.conv_section4 = nn.Sequential(
            nn.Conv1d(32, 32, 9, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Conv1d(32, 32, 9, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(32)
        )

        self.conv_section5 = nn.Sequential(
            nn.Conv1d(32, 64, 7, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Conv1d(64, 64, 7, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(64)
        )

        # conv_section6 is disabled. It is deliberately NOT registered as a submodule, so saved
        # state_dicts keep their current keys. Its former definition:
        # self.conv_section6 = nn.Sequential(
        #     nn.Conv1d(64, 64, 19, padding='same'),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(64),
        #     nn.Conv1d(64, 64, 19, padding='same'),
        #     nn.ReLU(),
        #     nn.BatchNorm1d(64)
        # )

        self.conv_section7 = nn.Sequential(
            nn.Conv1d(64, 80, 7, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(80),
            nn.Conv1d(80, 80, 7, padding='same'),
            nn.ReLU(),
            nn.BatchNorm1d(80)
        )

        self.encoder_linear = nn.Linear(5120, z_dim*2)
        self.decoder_linear = nn.Linear(z_dim, 5120)

        self.decoder_batchnorm = nn.BatchNorm1d(5120)

        self.transconv_section1 = nn.Sequential(
            nn.ConvTranspose1d(16, 1, 19, padding=9, stride=1),
        )

        self.transconv_section2 = nn.Sequential(
            nn.ConvTranspose1d(16, 16, 20, padding=9, stride=2),
            nn.ReLU(),
            nn.BatchNorm1d(16),
        )

        self.transconv_section3 = nn.Sequential(
            nn.ConvTranspose1d(32, 16, 20, padding=9, stride=2),
            nn.ReLU(),
            nn.BatchNorm1d(16),
        )

        self.transconv_section4 = nn.Sequential(
            nn.ConvTranspose1d(48, 32, 10, padding=4, stride=2),
            nn.ReLU(),
            nn.BatchNorm1d(32),
        )

        self.transconv_section5 = nn.Sequential(
            nn.ConvTranspose1d(64, 48, 10, padding=4, stride=2),
            nn.ReLU(),
            nn.BatchNorm1d(48),
        )

        self.transconv_section6 = nn.Sequential(
            nn.ConvTranspose1d(64, 64, 8, padding=3, stride=2),
            nn.ReLU(),
            nn.BatchNorm1d(64)
        )

        self.transconv_section7 = nn.Sequential(
            nn.ConvTranspose1d(80, 64, 7, padding=3, stride=1),
            nn.ReLU(),
            nn.BatchNorm1d(64)
        )

        # Not used by encode/decode/forward. It is kept so external references to model.dropout still
        # work (it has no parameters, so it does not affect state_dicts).
        self.dropout = nn.Dropout()

    def encode(self, x):
        """Encode (batch, 1, 2048) signals into (batch, 2 * z_dim) = [means | raw stds]."""
        # Each comment gives the shape of x at that point.
        # [B, 1, 2048]
        x = self.conv_section1(x)
        x = nn.functional.max_pool1d(x, 2)

        # [B, 16, 1024]
        x = self.conv_section2(x) + x  # residual connection
        x = nn.functional.max_pool1d(x, 2)

        # [B, 16, 512]
        x = self.conv_section3(x)
        x = nn.functional.max_pool1d(x, 2)

        # [B, 32, 256]
        x = self.conv_section4(x) + x  # residual connection
        x = nn.functional.max_pool1d(x, 2)

        # [B, 32, 128]
        x = self.conv_section5(x)
        x = nn.functional.max_pool1d(x, 2)

        # [B, 64, 64]  (conv_section6 is disabled, see __init__)
        x = self.conv_section7(x)

        # [B, 80, 64]
        x = torch.flatten(x, -2)

        # [B, 5120]
        return self.encoder_linear(x)

    def decode(self, z):
        """Decode (batch, z_dim) latents into (batch, 1, 2048) signals."""
        # Each comment gives the shape of z at that point.
        # [B, z_dim]
        z = self.decoder_linear(z)
        z = self.decoder_batchnorm(z)
        z = torch.nn.functional.relu(z)

        # [B, 5120]
        z = torch.reshape(z, (-1, 80, 64))
        # [B, 80, 64]
        z = self.transconv_section7(z)
        # [B, 64, 64]
        z = self.transconv_section6(z)
        # [B, 64, 128]
        z = self.transconv_section5(z)
        # [B, 48, 256]
        z = self.transconv_section4(z)
        # [B, 32, 512]
        z = self.transconv_section3(z)
        # [B, 16, 1024]
        z = self.transconv_section2(z)
        # [B, 16, 2048]
        z = self.transconv_section1(z)
        # [B, 1, 2048]
        return z

    def forward(self, x):
        """Encode ``x``, draw one reparameterised latent sample and decode it.

        Returns:
            (reconstruction, z_dist). The reconstruction is (batch, 1, 2048). z_dist is the raw encoder
            output (batch, 2 * z_dim); abs() of its second half is used as the sampling std.
        """
        z_dist = self.encode(x)
        mu = z_dist[:, :self.z_dim]
        std = torch.abs(z_dist[:, self.z_dim:])
        # Draw the noise directly with z_dist's device and dtype instead of on the CPU followed by a copy.
        z = torch.randn_like(mu) * std + mu

        return self.decode(z), z_dist

In [None]:
# Map-style PyTorch dataset over the segmented ECG DataFrames
from torch.utils.data import Dataset, DataLoader

class Dataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over a DataFrame of ECG windows.

    Position ``index`` yields ``(data, class_index, row_label)`` taken from that row.
    Note: this class shadows the ``Dataset`` name imported just above.
    """

    def __init__(self, dataset):
        """Keep a reference to the wrapped DataFrame (no copy is made)."""
        self.dataset = dataset

    def __len__(self):
        """Number of rows in the wrapped DataFrame."""
        return self.dataset.shape[0]

    def __getitem__(self, index):
        """Return (signal, class index, row label) for the row at position ``index``."""
        sample = self.dataset.iloc[index]
        return sample["data"], sample["class_index"], sample.name

In [None]:
def split_to_segments(dataset, new_len, orig_len, overlap=0):
    """Cut every recording in ``dataset`` into fixed-length, optionally overlapping windows.

    Args:
        dataset: DataFrame with one recording per row; its ``data`` column holds a sliceable signal.
        new_len: Length (in samples) of each output window.
        orig_len: Recording length used to count windows; every row yields the same number of windows.
        overlap: Fraction of ``new_len`` shared by consecutive windows; must leave a step of >= 1 sample.

    Returns:
        A DataFrame with one row per window. All source columns are copied, ``data`` holds the window,
        ``rec_ind`` holds the source row label and ``rec_pos`` holds the window's position in its recording.
        The result has a fresh RangeIndex; the source label is also kept in an ``index`` column.

    Raises:
        ValueError: If ``overlap`` makes the window step smaller than one sample. Previously this gave a
            ZeroDivisionError (step 0) or a silently empty result (negative step).
    """
    step = int(round(new_len * (1 - overlap)))
    if step < 1:
        raise ValueError(f"overlap={overlap} gives a window step of {step} samples; it must be at least 1")

    # Number of complete windows: floor((orig_len - new_len) / step) + 1.
    num_sections = (orig_len - (new_len - step)) // step

    sections = []
    for _, series in dataset.iterrows():
        for i in range(num_sections):
            section_series = series.copy()
            section_series["data"] = section_series["data"][i * step: i * step + new_len]
            section_series["rec_ind"] = series.name
            section_series["rec_pos"] = i
            # Keep all other data (ptID, measDiag etc.) the same for each section as the source ECG
            sections.append(section_series)

    return pd.DataFrame(sections).reset_index()

In [None]:
# For MIT noise stress test database
import wfdb
import os
import scipy.signal

noise_stress_test_db = "mit-bih-noise-stress-test-database"
# The two NSTDB records used; each "{}" template below is filled with one of them.
records = ["118", "119"]
# One file template per noise level, ordered from the least noisy suffix (e24) to the noisiest (e_6).
stress_test_files = ["{}e" + level for level in ("24", "18", "12", "06", "00", "_6")]

# Filled by the loading loop: one noise-level label and one list of segments per successfully loaded file.
noise_level, segments_lists = [], []

# Additionally band pass filter
def filter_ecg(x, fs):
    """Zero-phase band-pass ``x`` to 0.66-30 Hz (3rd-order Butterworth) and min-max scale it to [0, 1].

    x: 1-D signal; fs: its sampling rate in Hz.
    """
    num, den = scipy.signal.butter(3, [0.66, 30], 'band', fs=fs)
    filtered = scipy.signal.filtfilt(num, den, x, padlen=150)
    lo = filtered.min()
    hi = filtered.max()
    return (filtered - lo) / (hi - lo)

# Timing of the NST noise (in seconds). Per the PhysioNet NSTDB description (verify against the database
# docs), noise starts 5 minutes into each record and comes in 2-minute bursts alternating with 2-minute
# clean stretches. The previous start of `5 * fs` samples was 5 *seconds*, inconsistent with the other
# second-based constants, so windows straddled clean/noisy boundaries.
TARGET_FS = 300
NOISE_START_S = 5 * 60
NOISE_PERIOD_S = 240
NOISE_SEGMENT_S = 120


def _load_nst_channel(record_name, fs=TARGET_FS):
    """Read signal column 1 of an NSTDB record, resample it to ``fs`` Hz and band-pass/normalise it."""
    record = wfdb.io.rdrecord(os.path.join(noise_stress_test_db, record_name))
    signal = record.p_signal[:, 1]
    signal = scipy.signal.resample(signal, int(signal.shape[0] * fs / record.fs))
    return filter_ecg(signal, fs)


def _cut_noise_segments(signal, fs=TARGET_FS):
    """Return every complete NOISE_SEGMENT_S-second window that starts at a noise-burst onset."""
    seg_len = NOISE_SEGMENT_S * fs
    # The upper bound drops a window truncated by the end of the record, which would make the list ragged.
    starts = np.arange(NOISE_START_S * fs, signal.shape[-1] - seg_len + 1, NOISE_PERIOD_S * fs)
    return [signal[start: start + seg_len] for start in starts]


for file in stress_test_files:
    try:
        print(f"Reading file: {file}")
        all_data_v1 = _load_nst_channel(file.format(records[0]))
        all_data_v2 = _load_nst_channel(file.format(records[1]))

        # Each record is cut at its own boundaries (the old loop reused record 1's boundaries for record 2).
        segments = _cut_noise_segments(all_data_v1) + _cut_noise_segments(all_data_v2)

        noise_level.append(file.split("e")[-1])
        segments_lists.append(segments)
    except ValueError:
        print("error, scipping file")
        continue

# The first file (e24, the least noisy) is the reference every noisier level is paired with.
if len(noise_level) < 2 or noise_level[0] != "24":
    raise RuntimeError("Need the e24 reference plus at least one noisier level to build clean/noisy pairs")

segments_lists = np.array(segments_lists)  # (n_levels, n_segments, n_samples)

# Stack (reference, noisy) pairs for every noisier level along the segment axis.
data = np.concatenate([segments_lists[[0, i]] for i in range(1, len(segments_lists))], axis=1)

noise_level = np.array(noise_level[1:])
noise_level = np.repeat(noise_level, len(segments_lists[0]))

# -> (n_pairs, n_samples, 2): channel 0 is the e24 reference, channel 1 the noisier version.
data = np.transpose(data, axes=(1, 2, 0))

print(data.shape)
print(noise_level.shape)
nst_df = pd.DataFrame({"data": [data[i] for i in range(data.shape[0])], "noise_level": noise_level})
print(nst_df.head())

pk_path = "mit-bih-noise-stress-test-database/database_denoising.pk"
nst_df.to_pickle(pk_path)

In [None]:

pk_path = "mit-bih-noise-stress-test-database/database_denoising.pk"
nst_df = pd.read_pickle(pk_path)
# normalise: z-score each channel of each 2-minute pair
nst_df["data"] = (nst_df["data"] - nst_df["data"].map(lambda x: x.mean(axis=0)))/nst_df["data"].map(lambda x: x.std(axis=0))

## Train test split here
# The pickle holds one block of rows per noise level, each listing the same segments in the same order.
# A row's position within its noise level therefore identifies the underlying ECG segment. Splitting on
# that ID keeps every noisy version of a segment on one side; sampling rows directly leaked the same
# ECG (and its clean reference) into both train and test.
segment_id = nst_df.groupby("noise_level").cumcount()
train_segments = pd.Series(segment_id.unique()).sample(frac=0.8)
is_train = segment_id.isin(train_segments)
train_dataset = nst_df[is_train]
test_dataset = nst_df[~is_train]

train_dataset = split_to_segments(train_dataset, 2048, 36000, 0.5)
test_dataset = split_to_segments(test_dataset, 2048, 36000, 0.5)

class NSTDataset(torch.utils.data.Dataset):
    """Map-style dataset of (noisy, clean, row label) triples from a segmented NST DataFrame.

    Each row's ``data`` is a 2-column array. Column 1 is returned as the model input and column 0 as
    the reconstruction target.
    """

    def __init__(self, dataset):
        """Keep a reference to the wrapped DataFrame."""
        self.dataset = dataset

    def __len__(self):
        """Number of rows (segments) in the wrapped DataFrame."""
        return self.dataset.shape[0]

    def __getitem__(self, index):
        """Return (noisy channel, clean channel, row label) for the row at position ``index``."""
        sample = self.dataset.iloc[index]
        pair = sample["data"]
        return pair[:, 1], pair[:, 0], sample.name

torch_dataset_train, torch_dataset_test = NSTDataset(train_dataset), NSTDataset(test_dataset)
# Same loader settings for both splits.
train_dataloader, test_dataloader = (
    DataLoader(torch_ds, batch_size=32, shuffle=True, pin_memory=True)
    for torch_ds in (torch_dataset_train, torch_dataset_test)
)

In [None]:
# Sanity-check one test batch: the noisy input, the clean target, and the clean channel looked up by row
# label (the last two should coincide). Only the first batch is plotted; the old loop drew one figure per batch.
signals, clean_signals, ind = next(iter(test_dataloader))
fig, ax = plt.subplots()
ax.plot(signals[0], label="noisy input")
ax.plot(clean_signals[0], label="clean target")
ax.plot(test_dataset.loc[int(ind[0])]["data"][:, 0], label="clean (looked up by label)")
ax.legend()
plt.show()

In [None]:
# For SAFER data
# Split train and test data according to each patient
def _zscore_windows(windows):
    """Z-score each array in ``windows["data"]`` on its own (in place) and return ``windows``."""
    windows["data"] = (windows["data"] - windows["data"].map(lambda x: x.mean()))/windows["data"].map(lambda x: x.std())
    return windows


def make_SAFER_dataloaders(pt_data, ecg_data, test_frac, only_clean_training=True):
    """Split SAFER ECGs into patient-disjoint train/test sets and wrap them in DataLoaders.

    Patients are grouped by their number of low-quality recordings (``noRecs - noHQrecs``), and
    ``test_frac`` of each group is sampled into the test set. Each selected recording is cut into
    2048-sample windows with 50% overlap (assuming 9120-sample recordings), and each window is z-scored.

    Args:
        pt_data: Patient table with ``ptID``, ``noRecs`` and ``noHQrecs`` columns. Not modified.
        ecg_data: ECG table with ``ptID``, ``data`` and ``class_index`` columns.
        test_frac: Fraction of each patient group sent to the test set.
        only_clean_training: If True, the train loader only yields windows with ``class_index == 0``.

    Returns:
        (train_dataloader, test_dataloader, train_dataset, test_dataset). A loader/dataset pair is None
        when its patient split is empty.
    """
    # Work on a copy: the old code added "noLQrecs" to the caller's DataFrame as a side effect.
    # For Feas1 this count might include recordings flagged by Zenicor as noisy.
    pt_data = pt_data.assign(noLQrecs=pt_data["noRecs"] - pt_data["noHQrecs"])

    train_patients = []
    test_patients = []
    for _, group in pt_data.groupby("noLQrecs"):
        test = group.sample(frac=test_frac)
        test_patients.append(test)
        train_patients.append(group[~group["ptID"].isin(test["ptID"])])

    train_pt_df = pd.concat(train_patients)
    test_pt_df = pd.concat(test_patients)

    print(f"Test high quality: {test_pt_df['noHQrecs'].sum()} low quality: {test_pt_df['noLQrecs'].sum()} ")
    print(f"Train high quality: {train_pt_df['noHQrecs'].sum()} low quality: {train_pt_df['noLQrecs'].sum()} ")

    train_dataloader = None
    test_dataloader = None

    train_dataset = None
    test_dataset = None

    if not train_pt_df.empty:
        train_dataset = _zscore_windows(
            split_to_segments(ecg_data[ecg_data["ptID"].isin(train_pt_df["ptID"])], 2048, 9120, 0.5)
        )

        if only_clean_training:
            torch_dataset_train = Dataset(train_dataset[train_dataset["class_index"] == 0])
        else:
            torch_dataset_train = Dataset(train_dataset)

        train_dataloader = DataLoader(torch_dataset_train, batch_size=64, shuffle=True, pin_memory=True)

    if not test_pt_df.empty:
        test_dataset = _zscore_windows(
            split_to_segments(ecg_data[ecg_data["ptID"].isin(test_pt_df["ptID"])], 2048, 9120, 0.5)
        )
        test_dataloader = DataLoader(Dataset(test_dataset), batch_size=64, shuffle=True, pin_memory=True)

    return train_dataloader, test_dataloader, train_dataset, test_dataset


train_dataloader, test_dataloader, train_dataset, test_dataset = make_SAFER_dataloaders(
    pt_data=feas2_pt_data,
    ecg_data=feas2_ecg_data,
    test_frac=0.2,
    only_clean_training=False,
)

In [None]:
# If we want noisy and clean test_data for evaluation, after training and testing on only clean data in
# the training loop: every Feas2 ECG not used for training goes into the test split (test_frac=1).
_, noisy_test_dataloader, _, noisy_test_dataset = make_SAFER_dataloaders(
    pt_data=feas2_pt_data,
    ecg_data=feas2_ecg_data.loc[~feas2_ecg_data["measID"].isin(train_dataset["measID"])],
    test_frac=1,
    only_clean_training=False,
)

In [None]:
# Set up dataloaders for the cleaned Feas1 + Feas2 data only.
train_dataloader, test_dataloader, train_dataset, test_dataset = make_SAFER_dataloaders(
    pt_data=all_clean_pt,
    ecg_data=all_clean_data,
    test_frac=0.2,
    only_clean_training=False,
)

### GMM latent-space prior (experimental; unclear whether it helps)

In [None]:
# Prior: a mixture of `num_gmm_components` diagonal Gaussians in a `gmm_latent_dim`-dimensional space.
num_gmm_components = 10
gmm_latent_dim = 60
gmm_means = torch.randn((num_gmm_components, gmm_latent_dim))
gmm_stds = torch.ones((num_gmm_components, gmm_latent_dim))
gmm_mixture_weights = (torch.ones(num_gmm_components)/num_gmm_components)
# Down-weighting applied to each component's KL before mixing.
gmm_kl_scale = 1/500

def kl_gauss(z_m, z_std, t_m, t_std):
    """KL(N(z_m, z_std^2) || N(t_m, t_std^2)) for diagonal Gaussians, one value per row of ``z_m``.

    z_m, z_std: (batch, dim) means / standard deviations.
    t_m, t_std: (dim,) target means / standard deviations, broadcast over the batch.
    Returns a (batch,) tensor.
    """
    z_var = z_std ** 2
    t_var = t_std ** 2

    term1 = torch.sum(torch.log(t_var)[None, :] - torch.log(z_var), dim=-1)
    term2 = torch.sum(z_var / t_var[None, :], dim=-1) - z_m.shape[-1]
    # Sum over the latent dimension only. Without dim=-1 this summed over the whole batch and added the
    # same batch-wide scalar to every sample's KL.
    term3 = torch.sum(((z_m - t_m[None, :]) ** 2) / t_var[None, :], dim=-1)

    return 0.5 * (term1 + term2 + term3)


def gmm_kl_latent_loss(z_m, z_std):
    """Batch mean of -log(sum_k w_k * exp(-gmm_kl_scale * KL_k)) against the GMM prior.

    Uses logsumexp instead of exp-then-log, so large KLs no longer underflow to log(0) = inf. The terms
    are built with torch.stack on z_m's device instead of being written into a CPU buffer.
    """
    device = z_m.device
    kl_divs = torch.stack(
        [kl_gauss(z_m, z_std, gmm_means[i].to(device), gmm_stds[i].to(device))
         for i in range(num_gmm_components)],
        dim=1,
    )  # (batch, num_gmm_components)
    kl_divs = kl_divs * gmm_kl_scale
    log_weights = torch.log(gmm_mixture_weights.to(device))
    return torch.mean(-torch.logsumexp(log_weights[None, :] - kl_divs, dim=1))

gmm_kl_latent_loss(torch.zeros((32, 60)), torch.ones((32, 60)))

In [None]:
num_epochs = 25

# Prefer the GPU whenever PyTorch can see one.
use_cuda = torch.cuda.is_available()
print("Using Cuda" if use_cuda else "Using CPU")
device = torch.device("cuda" if use_cuda else "cpu")

z_dim = 128

model = CVAE(z_dim).to(device)

# Optional inverse-frequency class weighting (currently disabled):
# class_counts = torch.tensor(dataset["class_index"].value_counts().values.astype(np.float32))
# class_weights = torch.nn.functional.normalize(1.0/class_counts, dim=0)


def kl_latent_loss(z_mean, z_std):
    """Regulariser pulling the latent distribution towards N(0, 1).

    Returns 1/500 times the mean, over all elements, of (-log var + var + mean^2 - 1), where
    var = z_std ** 2.
    """
    variance = z_std ** 2
    kl_terms = -torch.log(variance) + variance + z_mean ** 2 - 1
    return 1/500 * torch.mean(kl_terms)

mse_loss = torch.nn.MSELoss()


def loss_func(x, s, z):
    """Latent regulariser on ``z`` plus the reconstruction MSE between ``x`` and ``s``.

    ``z`` is the raw encoder output: its first z_dim columns are means and the rest standard deviations.
    """
    return kl_latent_loss(z[:, :z_dim], z[:, z_dim:]) + mse_loss(x, s)


optimizer = torch.optim.SGD(model.parameters(), momentum=0.8, lr=0.03)
num_batches, num_test_batches = len(train_dataloader), len(test_dataloader)

In [None]:
def dist(x, y):
    """Euclidean distance between ``x`` and ``y``, treating all elements as one flat vector."""
    squared_diff = (x - y) ** 2
    return np.sqrt(np.sum(squared_diff))

triplet_margin = 0.02

def triplet_latent_loss(za, zp, zn):
    """Hinge triplet loss: max(d(anchor, positive) - d(anchor, negative) + margin, 0)."""
    gap = dist(za, zp) - dist(za, zn) + triplet_margin
    return 0 if gap < 0 else gap

### Train on Feas1 and Feas2

In [None]:
import math

# warning: changing these chunk sizes may reload feas1 data from scratch, which will take ages
chunk_size = 20000
# 162515 is presumably the total number of Feas1 ECGs; they are loaded chunk_size at a time.
num_chunks = math.ceil(162515 / chunk_size)

def get_feas1_dataloader(chunk_num):
    """Build a shuffled training DataLoader (batch size 128) for one chunk of Feas1 ECGs.

    Loads ECGs in [chunk_size * chunk_num, chunk_size * (chunk_num + 1)) through SAFERDataset (the
    dataframe_{chunk_num}.pk argument presumably names a cache file). It then cuts them into
    2048-sample windows with 50% overlap and z-scores each window.
    """
    start, stop = chunk_size * chunk_num, chunk_size * (chunk_num + 1)
    _, chunk_ecgs = SAFERDataset.load_feas_dataset(1, f"dataframe_{chunk_num}.pk", ecg_range=[start, stop])
    windows = split_to_segments(chunk_ecgs, 2048, 9120, 0.5)
    raw = windows["data"]
    windows["data"] = (raw - raw.map(lambda x: x.mean()))/raw.map(lambda x: x.std())
    return DataLoader(Dataset(windows), batch_size=128, shuffle=True, pin_memory=True)

In [None]:
import copy

model = model.to(device)

# Keep a CPU copy of the best model by test loss. +inf guarantees the first evaluated epoch is recorded
# (the old sentinel of 100 kept the starting model whenever the test loss never fell below 100).
best_test_loss = float("inf")
best_model = copy.deepcopy(model).cpu()

for epoch in range(num_epochs):
    total_loss = 0
    seen_batches = 0
    print(f"starting epoch {epoch} ...")
    # Train
    model.train()

    # Dataset 0 is the Feas2 train loader; datasets 1..num_chunks are Feas1 chunks, reloaded every epoch.
    for ds_ind in range(num_chunks + 1):
        print(f"training on dataset: {ds_ind}")
        if ds_ind == 0:
            train_dataloader_part = train_dataloader
        else:
            train_dataloader_part = get_feas1_dataloader(ds_ind - 1)

        for signals, _, _ in train_dataloader_part:
            signals = torch.unsqueeze(signals.to(device), 1).float()

            optimizer.zero_grad()
            output, latents = model(signals)
            loss = loss_func(output, signals, latents)

            loss.backward()

            nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2)

            optimizer.step()
            total_loss += float(loss)
            seen_batches += 1

        # Running average over every batch seen this epoch. The old divisor, num_batches, only counted the
        # Feas2 loader, so the printed value grew with every Feas1 chunk.
        print(f"Total loss {total_loss / max(seen_batches, 1)}")

    print(f"Epoch {epoch} finished with average loss {total_loss / max(seen_batches, 1)}")
    print("Testing ...")
    # Test
    test_loss = 0
    with torch.no_grad():
        model.eval()
        for signals, _, _ in test_dataloader:
            signals = torch.unsqueeze(signals.to(device), 1).float()

            output, latents = model(signals)
            loss = loss_func(output, signals, latents)
            test_loss += float(loss)

    avg_test_loss = test_loss / max(len(test_dataloader), 1)
    print(f"Average test loss: {avg_test_loss}")

    if avg_test_loss < best_test_loss:
        best_model = copy.deepcopy(model).cpu()
        best_test_loss = avg_test_loss

model = best_model

### Train only using feas2

In [None]:
import copy

model = model.to(device)

# Keep a CPU copy of the best model by test loss; +inf guarantees the first evaluated epoch is recorded.
best_test_loss = float("inf")
best_model = copy.deepcopy(model).cpu()

for epoch in range(num_epochs):
    total_loss = 0
    seen_batches = 0
    print(f"starting epoch {epoch} ...")
    # Train
    model.train()
    for signals, _, _ in train_dataloader:
        signals = torch.unsqueeze(signals.to(device), 1).float()

        optimizer.zero_grad()
        output, latents = model(signals)
        # Latents stay on the model's device (they used to be copied to the CPU for the KL term only).
        loss = loss_func(output, signals, latents)

        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2)

        optimizer.step()
        total_loss += float(loss)
        seen_batches += 1

    # Average over the batches actually seen (num_batches may be stale if train_dataloader was rebuilt).
    print(f"Epoch {epoch} finished with average loss {total_loss / max(seen_batches, 1)}")
    print("Testing ...")
    # Test
    test_loss = 0
    with torch.no_grad():
        model.eval()
        for signals, _, _ in test_dataloader:
            signals = torch.unsqueeze(signals.to(device), 1).float()

            output, latents = model(signals)
            loss = loss_func(output, signals, latents)
            test_loss += float(loss)

    avg_test_loss = test_loss / max(len(test_dataloader), 1)
    print(f"Average test loss: {avg_test_loss}")

    if avg_test_loss < best_test_loss:
        best_model = copy.deepcopy(model).cpu()
        best_test_loss = avg_test_loss

model = best_model

### Train using the NST dataset

In [None]:
import copy

# NOTE(review): this trains on whatever train_dataloader/test_dataloader currently hold; they must be the
# NST loaders (yielding noisy, clean, label triples), i.e. rebuilt after any SAFER loader cell was run.
model = model.to(device)

# Keep a CPU copy of the best model by test loss; +inf guarantees the first evaluated epoch is recorded.
best_test_loss = float("inf")
best_model = copy.deepcopy(model).cpu()

for epoch in range(num_epochs):
    total_loss = 0
    seen_batches = 0
    print(f"starting epoch {epoch} ...")
    # Train: reconstruct the clean signal from its noisy version.
    model.train()
    for signals, clean_signals, _ in train_dataloader:
        signals = torch.unsqueeze(signals.to(device), 1).float()
        clean_signals = torch.unsqueeze(clean_signals.to(device), 1).float()

        optimizer.zero_grad()
        output, latents = model(signals)
        # Latents stay on the model's device (they used to be copied to the CPU for the KL term only).
        loss = loss_func(output, clean_signals, latents)

        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2)

        optimizer.step()
        total_loss += float(loss)
        seen_batches += 1

    # Average over the batches actually seen (num_batches may belong to a different loader).
    print(f"Epoch {epoch} finished with average loss {total_loss / max(seen_batches, 1)}")
    print("Testing ...")
    # Test
    test_loss = 0
    with torch.no_grad():
        model.eval()
        for signals, clean_signals, _ in test_dataloader:
            signals = torch.unsqueeze(signals.to(device), 1).float()
            clean_signals = torch.unsqueeze(clean_signals.to(device), 1).float()

            output, latents = model(signals)
            loss = loss_func(output, clean_signals, latents)
            test_loss += float(loss)

    avg_test_loss = test_loss / max(len(test_dataloader), 1)
    print(f"Average test loss: {avg_test_loss}")

    if avg_test_loss < best_test_loss:
        best_model = copy.deepcopy(model).cpu()
        best_test_loss = avg_test_loss

model = best_model

In [None]:
model.to(device);  # if train finished use this to put back on the GPU (nn.Module.to moves in place and returns the same module)

In [None]:
# If training did not finish, continue from the best intermediate result.
best_model.to(device)
model = best_model

In [None]:
# Save the weights and the exact training windows under one file stem, so later evaluation can exclude them.
save_stem = "TrainedModels/Autoencoder_50_epochs_nst"
torch.save(model.state_dict(), save_stem + ".pt")
train_dataset.to_pickle(save_stem + "_train_set.pk")

In [None]:
z_dim = 128
model = CVAE(z_dim).to(device)
# NOTE(review): torch.load unpickles the checkpoint; only load files from trusted sources.
state = torch.load("TrainedModels/Autoencoder_new_6_epochs_all_feas1_feas2.pt", map_location=device)
model.load_state_dict(state)

In [None]:
# Reload the training set saved with the model, so we don't test on windows we trained on.
train_dataset = pd.read_pickle("TrainedModels/Autoencoder_new_6_epochs_all_feas1_feas2_train_set.pk")

# If the pickle already holds z-scored windows, this re-normalisation is (numerically) a no-op.
train_dataset["data"] = (train_dataset["data"] - train_dataset["data"].map(lambda x: x.mean()))/train_dataset["data"].map(lambda x: x.std())
torch_dataset_train = Dataset(train_dataset)
train_dataloader = DataLoader(torch_dataset_train, batch_size=32, shuffle=True, pin_memory=True)

# Test on every Feas2 patient with no windows in the saved training set.
test_pt_df = feas2_pt_data[~feas2_pt_data["ptID"].isin(train_dataset["ptID"])]

if not test_pt_df.empty:
    test_dataset = split_to_segments(feas2_ecg_data[feas2_ecg_data["ptID"].isin(test_pt_df["ptID"])], 2048, 9120, 0.5)
    # Drop Undecided windows from the freshly built test set. The old code filtered the *previous*
    # test_dataset and then overwrote it, so the filter never took effect. Labels are not reset, so
    # hard-coded window labels used later still refer to the same windows.
    test_dataset = test_dataset[test_dataset["measDiag"] != DiagEnum.Undecided].copy()
    test_dataset["data"] = (test_dataset["data"] - test_dataset["data"].map(lambda x: x.mean()))/test_dataset["data"].map(lambda x: x.std())
    torch_dataset_test = Dataset(test_dataset)
    test_dataloader = DataLoader(torch_dataset_test, batch_size=32, shuffle=True, pin_memory=True)

### Reconstruction for clean samples

In [None]:
# Plot test data reconstruction
test_dataset["reconstruction"] = None
mse_only_loss = lambda truth, pred: torch.mean((truth - pred) ** 2, dim=(1, 2))

# The model is fed polarity-inverted ECGs here (the plotting loop further down negates the reconstruction
# back). Set to False to evaluate on the ECGs as recorded.
INVERT_INPUT = True

with torch.no_grad():
    model.eval()

    r_err = []
    inds = []
    reconstructions = []

    for signals, _, ind in test_dataloader:
        signals = torch.unsqueeze(signals.to(device), 1).float()
        model_input = -signals if INVERT_INPUT else signals

        output, latents = model(model_input)
        # Score the reconstruction against what the model was actually given. The old code compared the
        # output for -signals with +signals, inflating every error.
        loss = mse_only_loss(output, model_input).detach().cpu().numpy()

        output = output.detach().cpu().numpy()

        for sample_ind, reconstruction, err in zip(ind, output[:, 0, :], loss):
            r_err.append(err)
            reconstructions.append(reconstruction)
            inds.append(int(sample_ind))


test_dataset["r_err"] = pd.Series(data=r_err, index=inds)
test_dataset["reconstruction"] = pd.Series(data=reconstructions, index=inds)

In [None]:
# Short alias used by the plotting cells below (the same DataFrame object, not a copy).
test_df = test_dataset

In [None]:
test_df.head(5)

In [None]:
test_df["reconstruction"].iat[0]

In [None]:
from matplotlib.ticker import AutoMinorLocator
import matplotlib
matplotlib.use('TkAgg')

def plot_ecg_and_reconstruction(x, r, fs=300, n_split=3):
    """Plot an ECG and its reconstruction on ECG-paper style axes, split across ``n_split`` stacked rows.

    Major x ticks and grid lines are every 0.2 s, with tick labels only at whole seconds. Both axes get
    5 minor divisions per major division, and y tick labels are hidden.

    Args:
        x: 1-D signal samples.
        r: 1-D reconstruction, same length as ``x``.
        fs: Sampling rate in Hz, used to build the time axis.
        n_split: Number of rows the trace is divided into.
    """
    sample_len = x.shape[0]
    time_axis = np.arange(sample_len)/fs

    cuts = np.round(np.linspace(0, sample_len-1, n_split+1)).astype(int)

    fig, ax = plt.subplots(n_split, 1, figsize=(16, 10), squeeze=False)
    for j in range(n_split):
        row_ax = ax[j][0]
        start, stop = cuts[j], cuts[j + 1]
        row_ax.plot(time_axis[start:stop], x[start:stop])
        row_ax.plot(time_axis[start:stop], r[start:stop])
        row_ax.set_xlabel("Time")
        row_ax.set_xlim((time_axis[start], time_axis[stop]))

        t_s = time_axis[start]
        t_f = time_axis[stop]
        time_ticks = np.arange(t_s - t_s % 0.2, t_f + (0.2 - t_f % 0.2), 0.2)
        # Label only the ticks that fall on whole seconds; the others get an empty label.
        decimal_labels = ~np.isclose(time_ticks, np.round(time_ticks))
        time_labels = np.round(time_ticks).astype(int).astype(str)
        time_labels[decimal_labels] = ""

        row_ax.set_xticks(time_ticks, time_labels)

        # Hide only the y tick labels. The x axis used to get a NullFormatter too, which replaced the
        # whole-second labels set just above, so they were computed but never shown.
        row_ax.yaxis.set_major_formatter(plt.NullFormatter())

        row_ax.xaxis.set_minor_locator(AutoMinorLocator(5))
        row_ax.yaxis.set_minor_locator(AutoMinorLocator(5))

        row_ax.grid(which='major', linestyle='-', linewidth='0.5', color='black')
        row_ax.grid(which='minor', linestyle='-', linewidth='0.5', color='lightgray')

    plt.show()

for _, row in test_df.loc[test_df["measDiag"] == DiagEnum.NoAF].iterrows():
    print(row[["ptDiag", "measDiag", "tag_orig_Poor_Quality", "poss_AF_tag", "r_err"]])
    # The reconstruction is negated for display, matching the inverted input fed to the model in the reconstruction cell.
    plot_ecg_and_reconstruction(row["data"], -row["reconstruction"], n_split=1)

In [None]:
def plot_ecg_and_reconstruction_for_classes(xs, rs, titles, fs=300,
                                            save_path="TMRFigures/cvae_reconst_examples_large_dataset.png"):
    """Plot several ECG/reconstruction pairs, one titled row each, on ECG-paper style axes, and save the figure.

    Args:
        xs: Sequence of 1-D signals.
        rs: Matching sequence of 1-D reconstructions.
        titles: Matching sequence of row titles.
        fs: Sampling rate in Hz used for the time axis.
        save_path: Output file for the figure (defaults to the previously hard-coded path).
    """
    # squeeze=False keeps `ax` 2-D even for a single example. With the default, one row gave a bare Axes
    # and ax[j] raised a TypeError.
    fig, ax = plt.subplots(len(xs), 1, figsize=(6, 7), squeeze=False)

    for j, (x, r, t) in enumerate(zip(xs, rs, titles)):
        row_ax = ax[j][0]
        sample_len = x.shape[0]
        time_axis = np.arange(sample_len)/fs

        row_ax.plot(time_axis, x)
        row_ax.plot(time_axis, r)
        row_ax.set_xlabel("Time")
        row_ax.set_xlim((time_axis[0], time_axis[-1]))

        # Major ticks every 0.2 s; tick labels are hidden below, so only the grid is drawn.
        row_ax.set_xticks(np.arange(time_axis[0], time_axis[-1]+0.2, 0.2))
        row_ax.set_title(t)

        row_ax.xaxis.set_major_formatter(plt.NullFormatter())
        row_ax.yaxis.set_major_formatter(plt.NullFormatter())

        row_ax.xaxis.set_minor_locator(AutoMinorLocator(5))
        row_ax.yaxis.set_minor_locator(AutoMinorLocator(5))

        row_ax.grid(which='major', linestyle='-', linewidth='0.5', color='black')
        row_ax.grid(which='minor', linestyle='-', linewidth='0.5', color='lightgray')

    fig.tight_layout()
    # plt.show()
    plt.savefig(save_path)

ecg_ind_list = [3195, 2916, 1563, 1561]  # 2192 # 441 # 315

# Look the chosen windows up once, then pull signals, reconstructions and titles from that slice.
examples = test_df.loc[ecg_ind_list]
xs = examples["data"].tolist()
rs = examples["reconstruction"].tolist()
titles = examples.apply(lambda x: f"{x['measDiag'].name} e = {x['r_err']:.3f}", axis=1)
print(len(titles))

plot_ecg_and_reconstruction_for_classes(xs, rs, titles)

### Latent space exploration

In [None]:
# Try some latent space exploration: encode one test batch, then set latent mean `index` to
# i * 4 - 2 (i = 0..9, i.e. -2, 2, ..., 34) for every sample and decode each sweep step.

with torch.no_grad():
    model.eval()
    signals, _, _ = next(iter(test_dataloader))
    signals = torch.unsqueeze(signals.to(device), 1).float()
    latent_position = model.encode(signals).detach().cpu().numpy()
    print(latent_position.shape)

index = 3
latent_positions = np.zeros((10, *latent_position.shape), dtype=np.float32)
for i in range(10):
    latent_positions[i, :, :] += latent_position
    latent_positions[i, :, index] = i * 4 - 2

signals = []

with torch.no_grad():
    model.eval()
    for l in latent_positions:
        # The decoder takes exactly model.z_dim inputs (the latent means). The old hard-coded 60 predates
        # z_dim = 128 and made decode fail with a shape mismatch.
        latent = torch.from_numpy(l[:, :model.z_dim]).to(device)
        signal = model.decode(latent)
        signals.append(signal.detach().cpu().numpy())

### Interpolate between a noisy and clean ECG!

In [None]:
# noisy 441 # clean 315
# Encode the two chosen test windows here: test_dataset has no "latent_encoding" column (latents are
# only stored on train_dataset, further down). Everything runs in eval/no_grad mode so decoding does not
# update BatchNorm running statistics.
noisy_ind, clean_ind = 441, 315

with torch.no_grad():
    model.eval()
    endpoints = np.stack([test_dataset.loc[noisy_ind]["data"], test_dataset.loc[clean_ind]["data"]])
    endpoints = torch.from_numpy(endpoints).float().unsqueeze(1).to(device)
    # Keep only the latent means (the first model.z_dim columns; previously a stale hard-coded 60).
    noisy_latent, clean_latent = model.encode(endpoints)[:, :model.z_dim].cpu().numpy()

    latent_sequence = np.linspace(noisy_latent, clean_latent, 32)
    latent_sequence = torch.from_numpy(latent_sequence).float().to(device)

    ecgs = model.decode(latent_sequence).detach().cpu().numpy()
print("plotting")

for ecg in np.flip(ecgs[:, 0, :], axis=0):
    plt.plot(ecg)
    plt.show()

### Find the reconstruction error for noisy and clean samples

In [None]:
# For NST data
# test_whole_ecgs_rec_err = test_dataset.groupby("rec_ind").agg({"r_err": "mean", "noise_level": lambda x: x.iloc[0]})
# For SAFER data: mean window error per recording, with the diagnosis taken from its first window.
per_recording_aggs = {"r_err": "mean", "measDiag": lambda x: x.iloc[0]}
test_whole_ecgs_rec_err = test_dataset.groupby("rec_ind").agg(per_recording_aggs)
test_whole_ecgs_rec_err

In [None]:
# Save the per-recording errors so results from the small and large datasets can be plotted on one axis
# (made for the TMR; it may not have made it into the final report).
pd.to_pickle(test_whole_ecgs_rec_err, "TrainedModels/Autoecoder_small_dataset.pk")

In [None]:
test_not_undecided = test_whole_ecgs_rec_err
diags_present = pd.unique(test_not_undecided["measDiag"])

fig = plt.figure(figsize=(6, 4), dpi=300)
diag_codes = test_not_undecided["measDiag"].map(lambda diag: diag.value)
plt.scatter(diag_codes, test_not_undecided["r_err"], marker='+')
plt.xticks([diag.value for diag in diags_present], [diag.name for diag in diags_present])
plt.ylabel("Reconstruction error")
plt.xlabel("Measurement diagnosis")
plt.tight_layout()
plt.show()

# plt.savefig("TMRFigures/cvae_reconst_err_large_dataset.png")

In [None]:
test_whole_ecgs_rec_err.measDiag.value_counts()

In [None]:
# Safer data: violin plot of the per-recording reconstruction error for each measurement diagnosis.

test_not_undecided = test_whole_ecgs_rec_err
# test_not_undecided_2 = pd.read_pickle("TrainedModels/Autoecoder_large_dataset.pk")

fig = plt.figure(figsize=(6, 4), dpi=300)

enum_order = [DiagEnum.NoAF, DiagEnum.AF, DiagEnum.PoorQuality]
data = [test_not_undecided.loc[test_not_undecided["measDiag"] == diag, "r_err"] for diag in enum_order]

# print(test_not_undecided_2["measDiag"].value_counts())
# data_2 = [test_not_undecided_2[test_not_undecided_2["measDiag"] == e]["r_err"] for e in enum_order]

plt.violinplot(data)    # quantiles=[[0.25, 0.75]]*4, showmedians=True)
# plt.violinplot(data_2)
plt.xticks(list(range(1, len(enum_order) + 1)), [diag.name for diag in enum_order])
plt.ylabel("Reconstruction error")
plt.xlabel("Measurement diagnosis")
plt.tight_layout()
plt.show()

# plt.savefig("TMRFigures/cvae_reconst_err_small_dataset.png")

In [None]:
# NST data: violin plot of the per-recording reconstruction error at each noise level.
# Needs test_whole_ecgs_rec_err built with the NST ("noise_level") aggregation in the cell above.


def _nst_level_value(level):
    """Convert an NST noise-level label such as "24", "06" or "_6" (minus six) into an int."""
    return int(str(level).replace("_", "-"))


test_not_undecided = test_whole_ecgs_rec_err

fig = plt.figure(figsize=(6, 4), dpi=300)

# Sort numerically. A plain string sort placed "_6" (i.e. -6) after "18".
noise_levels = np.array(sorted(pd.unique(test_not_undecided["noise_level"]), key=_nst_level_value))
print(noise_levels)
data = [test_not_undecided[test_not_undecided["noise_level"] == e]["r_err"] for e in noise_levels]


plt.violinplot(data)    # quantiles=[[0.25, 0.75]]*4, showmedians=True)
# One labelled tick per violin (the old fixed positions 1..5 had no labels and assumed five levels).
plt.xticks(np.arange(1, len(noise_levels) + 1), [str(_nst_level_value(level)) for level in noise_levels])
plt.ylabel("Reconstruction error")
plt.xlabel("Noise level (NSTDB SNR, dB)")
plt.tight_layout()
plt.show()

# plt.savefig("TMRFigures/cvae_reconst_err_small_dataset.png")

### Sample cross validation code for SAFER (not yet applied to anything)

## Classification from the latent space

In [None]:
from sklearn.svm import SVC

# Encode every training segment and store each latent vector against the segment's dataset row index.
with torch.no_grad():
    model.eval()
    latents = []
    inds = []

    for signals, _, ind in train_dataloader:
        # Add a channel axis (dim 1) before encoding.
        signals = torch.unsqueeze(signals.to(device), 1).float()
        # fft = torch.abs(torch.fft.fft(signals))
        # signals = torch.cat([signals, fft], dim=1)
        # labels = labels.type(torch.LongTensor)

        latent_position = model.encode(signals)
        latent_position = latent_position.detach().cpu().numpy()

        for row_ind, latent in zip(ind, latent_position):
            latents.append(latent)
            # int(): the DataLoader presumably collates indices into a tensor, so each element is a
            # 0-d tensor; plain ints are needed for the Series index to align with train_dataset's
            # index (the test-set cell does the same).
            inds.append(int(row_ind))

train_dataset["latent_encoding"] = pd.Series(data=latents, index=inds)
# Keep only the rows that received an encoding.
svc_train_df = train_dataset.dropna(subset=["latent_encoding"])

In [None]:
# Exclude segments whose measurement diagnosis is still Undecided, then show the full training frame.
is_decided = svc_train_df["measDiag"] != DiagEnum.Undecided
svc_train_df = svc_train_df.loc[is_decided]
train_dataset

#### Visualise the latent space with scatter plots and t-SNE

In [None]:
# Number of leading entries of each latent_encoding plotted as latent means.
N_LATENT_MEANS = 60

latent_list = list(svc_train_df["latent_encoding"].map(lambda x: x[:N_LATENT_MEANS].tolist()).values)
latent_df = pd.DataFrame(latent_list, index=svc_train_df.index)
print(latent_df.columns)

# Reference latent dimension plotted on the x-axis of every scatter plot.
latent_ind = 0

# Row masks per diagnosis, computed once instead of twice per figure.
diag_masks = {
    d: svc_train_df["measDiag"] == d
    for d in [DiagEnum.NoAF, DiagEnum.PoorQuality, DiagEnum.AF, DiagEnum.CannotExcludePathology]
}

# One scatter plot per latent dimension against the reference dimension.
for i in range(N_LATENT_MEANS):
    if i == latent_ind:
        continue  # a dimension plotted against itself is just a diagonal line
    fig, ax = plt.subplots(figsize=(6, 4), dpi=300)
    for d, mask in diag_masks.items():
        ax.scatter(latent_df.loc[mask, latent_ind], latent_df.loc[mask, i], marker="x", label=d.name)
    ax.legend()
    ax.set_ylabel(f"latent mean {i}")
    ax.set_xlabel(f"latent mean {latent_ind}")
    plt.show()
    plt.close(fig)  # avoid accumulating dozens of open figures

In [None]:
from sklearn.manifold import TSNE

# t-SNE of the per-segment latent means (first 60 entries of each encoding), coloured by diagnosis.
latent_matrix = np.array(list(svc_train_df["latent_encoding"].map(lambda x: x[:60].tolist()).values))
latent_classes = svc_train_df["measDiag"].values

print("starting tsne")
# random_state fixes the random initialisation so the embedding is reproducible across runs.
TSNE_SEED = 0
tsne = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=30, random_state=TSNE_SEED)

X_embedded = tsne.fit_transform(latent_matrix)

fig, ax = plt.subplots()
for d in [DiagEnum.NoAF, DiagEnum.PoorQuality, DiagEnum.AF, DiagEnum.CannotExcludePathology]:
    mask = latent_classes == d
    ax.scatter(X_embedded[mask, 0], X_embedded[mask, 1], marker="x", label=d.name)

ax.set_title("t-SNE of segment latent means")
ax.legend()
plt.show()

#### Group the segment encodings by recording

In [None]:
# Drop Undecided segments (repeats the earlier filter so this cell also works when run on its own).
svc_train_df = svc_train_df.loc[~(svc_train_df["measDiag"] == DiagEnum.Undecided)]

def concatenate_means(x, n_means=60):
    """Concatenate the leading latent entries of every segment in a group into one vector.

    Intended as a ``groupby(...).agg`` reducer over the ``latent_encoding`` column: ``x`` holds
    one 1-D array per segment of a recording, and the first ``n_means`` values of each are kept.

    Args:
        x: Series (or any iterable) of 1-D array-likes, one per segment, in segment order.
        n_means: number of leading entries kept from each array (default 60, the slice
            used elsewhere in this notebook).

    Returns:
        1-D numpy array formed by joining the truncated arrays in order; an empty array
        when ``x`` is empty.
    """
    mean_arrays = [np.asarray(encoding)[:n_means] for encoding in x]
    if not mean_arrays:
        # np.concatenate raises on an empty list; return an empty vector instead.
        return np.empty(0)
    return np.concatenate(mean_arrays)

def _first_row_value(col):
    """Return the first value of a group's column (labels presumably constant within a recording)."""
    return col.iloc[0]

# One row per recording: concatenated segment latent means plus the recording's labels.
recording_aggregations = {
    "latent_encoding": concatenate_means,
    "measDiag": _first_row_value,
    "class_index": _first_row_value,
}
full_ecg_train_df = svc_train_df.groupby("rec_ind").agg(recording_aggregations)
full_ecg_train_df.iloc[0]["latent_encoding"].shape

In [None]:
# t-SNE on whole recordings: each point is the concatenation of a recording's segment latent means.
# Imported here too so this cell does not depend on the earlier t-SNE cell having been run.
from sklearn.manifold import TSNE

# np.vstack (as used for train_matrix) instead of a tolist()/np.array round-trip; it raises a
# clear error if recordings have different numbers of segments.
latent_matrix = np.vstack(full_ecg_train_df["latent_encoding"].values)
latent_classes = full_ecg_train_df["measDiag"].values

print("starting tsne")
# random_state fixes the random initialisation so the embedding is reproducible across runs.
TSNE_SEED = 0
tsne = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=30, random_state=TSNE_SEED)

X_embedded = tsne.fit_transform(latent_matrix)

fig, ax = plt.subplots()
for d in [DiagEnum.NoAF, DiagEnum.PoorQuality, DiagEnum.AF, DiagEnum.CannotExcludePathology]:
    mask = latent_classes == d
    ax.scatter(X_embedded[mask, 0], X_embedded[mask, 1], marker="x", label=d.name)

ax.set_title("t-SNE of per-recording concatenated latent means")
ax.legend()
plt.show()

In [None]:
# Fit an SVM on the per-recording latent vectors (one row per recording).
train_matrix = np.vstack(full_ecg_train_df["latent_encoding"].values)
targets = full_ecg_train_df["class_index"].astype(int).to_numpy()

print(train_matrix.shape)

# TODO: consider class weighting (e.g. SVC(class_weight="balanced")) if the classes are imbalanced.
classifier = SVC().fit(train_matrix, targets)

### Testing

In [None]:
# Encode every test segment and store each latent vector against the segment's dataset row index.
test_dataset["latent_encoding"] = None
inds = []
latents = []

with torch.no_grad():
    model.eval()
    for signals, _, ind in test_dataloader:
        # Add a channel axis (dim 1) before encoding.
        signals = torch.unsqueeze(signals.to(device), 1).float()
        # fft = torch.abs(torch.fft.fft(signals))
        # signals = torch.cat([signals, fft], dim=1)
        # labels = labels.type(torch.LongTensor)

        latent_position = model.encode(signals).detach().cpu().numpy()

        # Pair each row index (as a plain int) with its latent vector.
        for row_ind, latent in zip(ind, latent_position):
            inds.append(int(row_ind))
            latents.append(latent)

test_dataset["latent_encoding"] = pd.Series(data=latents, index=inds)

In [None]:
# Collapse test segments into one row per recording (same aggregation as for training),
# then drop recordings whose diagnosis is Undecided.
full_ecg_test_df = test_dataset.groupby("rec_ind").agg({"latent_encoding": concatenate_means, "measDiag": lambda x: x.iloc[0], "class_index": lambda x: x.iloc[0], "measID": lambda x: x.iloc[0]})
# .copy() so the "prediction" column added below is written to an independent frame rather than
# a boolean-filtered slice of full_ecg_test_df (avoids pandas' SettingWithCopyWarning).
full_ecg_no_undecided_test_df = full_ecg_test_df[full_ecg_test_df["measDiag"] != DiagEnum.Undecided].copy()

test_matrix = np.vstack(full_ecg_no_undecided_test_df["latent_encoding"].values)
targets = np.array(full_ecg_no_undecided_test_df["class_index"].astype(int).values)
print(test_matrix.shape)

prediction = classifier.predict(test_matrix)

full_ecg_no_undecided_test_df["prediction"] = prediction

In [None]:
from sklearn.metrics import confusion_matrix

# Rows are true class_index values, columns are SVM predictions.
y_true = full_ecg_no_undecided_test_df["class_index"].astype(int)
y_pred = full_ecg_no_undecided_test_df["prediction"].astype(int)
conf_mat = confusion_matrix(y_true, y_pred)
print("Confusion matrix:")
print(conf_mat)

def F1_ind(conf_mat, ind):
    """F1 score of class ``ind`` from a square numpy confusion matrix (rows true, columns predicted).

    F1 = 2*TP / (row total + column total), i.e. 2*TP / (2*TP + FN + FP).
    """
    true_pos = conf_mat[ind, ind]
    actual_total = conf_mat[ind].sum()
    predicted_total = conf_mat[:, ind].sum()
    return 2 * true_pos / (actual_total + predicted_total)

# Class 1 is the positive class: sensitivity is its recall, specificity is class 0's recall.
sensitivity = conf_mat[1, 1] / conf_mat[1].sum()
specificity = conf_mat[0, 0] / conf_mat[0].sum()
print(f"Sensitivity: {sensitivity}")
print(f"Specificity: {specificity}")

print(f"Normal F1: {F1_ind(conf_mat, 0)}")
print(f"Noisy F1: {F1_ind(conf_mat, 1)}")

In [None]:
# Recordings whose true class_index is 0 but which the SVM predicted as class 1.
false_positives = full_ecg_no_undecided_test_df[(full_ecg_no_undecided_test_df["class_index"] == 0) & (full_ecg_no_undecided_test_df["prediction"] == 1)]

# NOTE(review): only feas2_ecg_data is searched, so false positives from other datasets are not
# plotted — confirm this is intended.
fp_ecgs = feas2_ecg_data.loc[feas2_ecg_data["measID"].isin(false_positives["measID"]), "data"]
# Series.iteritems() was removed in pandas 2.0; .items() is the equivalent and exists on older pandas too.
for _, ecg in fp_ecgs.items():
    # The same ECG is passed for both arguments; presumably only the raw trace matters here — confirm.
    plot_ecg_and_reconstruction(ecg, ecg)
    plt.show()