In [102]:
import matplotlib.pyplot as plt
import plotly.graph_objects as go

import numpy as np
import pandas as pd

import sys
import os

import scipy.signal
from ecgdetectors import Detectors

import math
from DataHandlers.DiagEnum import DiagEnum
import DataHandlers.DiagEnum
import DataHandlers.SAFERDataset as SAFERDataset
import DataHandlers.CinC2020Dataset as CinC2020Dataset
import DataHandlers.CinC2020Enums
import importlib
import DataHandlers.CinCDataset as CinCDataset
import DataHandlers.DataAugmentations as DataAugmentations
from multiprocesspandas import applyparallel
importlib.reload(SAFERDataset)
importlib.reload(CinC2020Dataset)

import DataHandlers.DataProcessUtilities
importlib.reload(DataHandlers.DataProcessUtilities)
from DataHandlers.DataProcessUtilities import *
import Utilities.Plotting
importlib.reload(Utilities.Plotting)
from Utilities.Plotting import *

# A fudge because I moved the files
sys.modules["SAFERDataset"] = SAFERDataset
sys.modules["CinC2020Dataset"] = CinC2020Dataset
sys.modules["DiagEnum"] = DataHandlers.DiagEnum
sys.modules["CinC2020Enums"] = DataHandlers.CinC2020Enums
sys.modules["CinCDataset"] = CinCDataset

In [6]:
from scipy.special import softmax

print(softmax(np.array([0.99704283, -0.9721041, 0.37072268])))

[0.59732477 0.08337213 0.3193031 ]


In [2]:
import torch
from torch import nn

enable_cuda = True

if torch.cuda.is_available() and enable_cuda:
    print("Using Cuda")
    device = torch.device("cuda")
else:
    print("Using CPU")
    device = torch.device("cpu")

Using Cuda


### Load SAFER data

In [4]:
feas2_pt_data, feas2_ecg_data = SAFERDataset.load_feas_dataset(2, "dataframe_reload")
feas2_ecg_data["measID"] += 300000
feas2_ecg_data.index = feas2_ecg_data["measID"]

D:\2022_23_DSiromani\Feas2\ECGs/filtered_dataframe_reload.pk


In [5]:
feas2_ecg_data["feas"] = 2

In [129]:
def reduce_normals_all_other_af(pt_data, ecg_data):
    accepted_meas_diags = [DiagEnum.AF, DiagEnum.NoAF, DiagEnum.HeartBlock]
    ecg_data = ecg_data[(ecg_data["measDiag"].isin(accepted_meas_diags)) | (ecg_data["measID"] < 20000) | (ecg_data["not_tagged_ign_wide_qrs"] == 0)]
    pt_data = pt_data[pt_data["ptID"].isin(ecg_data["ptID"])]

    return pt_data, ecg_data

# warning: changing these chunk sizes may reload feas1 data from scratch, which will take ages
chunk_size = 20000
num_chunks = math.ceil(162515 / chunk_size )

def load_feas1_chunk_range(chunk_range=(0, num_chunks)):
    ecg_data = []
    pt_data = []

    for chunk_num in range(chunk_range[0], chunk_range[1]):
        feas1_pt_data, feas1_ecg_data = SAFERDataset.load_feas_dataset(1, f"dataframe_{chunk_num}.pk")

        ecg_data.append(feas1_ecg_data)
        pt_data.append(feas1_pt_data)

    feas1_ecg_data = pd.concat(ecg_data)
    feas1_ecg_data["feas"] = 1
    feas1_ecg_data["rri_len"] = feas1_ecg_data["rri_feature"].map(lambda x: x[x > 0].shape[-1])
    feas1_pt_data = pd.concat(pt_data).drop_duplicates()

    return feas1_ecg_data, feas1_pt_data

In [None]:
feas1_ecg_data, feas1_pt_data = load_feas1_chunk_range((0, num_chunks))

In [134]:
def prepare_safer_data(pt_data, ecg_data):
    if "length" in ecg_data:
        ecg_data = ecg_data[ecg_data["length"] == 9120]

    ecg_data = ecg_data[ecg_data["measDiag"] != DiagEnum.PoorQuality]
    # ecg_data = ecg_data[ecg_data["tag_orig_Poor_Quality"] == 0]

    ecg_data = ecg_data[ecg_data["rri_len"] > 5]


    pt_data.index = pt_data["ptID"]
    ecg_data = SAFERDataset.generate_af_class_labels(ecg_data)
    pt_data = SAFERDataset.add_ecg_class_counts(pt_data, ecg_data)

    return pt_data, ecg_data

In [8]:
# just use feas2
safer_ecg_data = feas2_ecg_data
safer_ecg_data["ffReview_sent"] = -1
safer_ecg_data["ffReview_remain"] = -1
safer_pt_data = feas2_pt_data

safer_pt_data, safer_ecg_data = prepare_safer_data(safer_pt_data, safer_ecg_data)

In [144]:
# Just use feas1 to prepare test and validation datasets (The train is best handled with a DatasetSequenceIterator)
feas1_pt_data, feas1_ecg_data = prepare_safer_data(feas1_pt_data, feas1_ecg_data)
feas1_ecg_data["class_index"].value_counts()

feas1_ecg_data_test = feas1_ecg_data[feas1_ecg_data["ptID"].isin(test_pts["ptID"])]
feas1_ecg_data_val = feas1_ecg_data[feas1_ecg_data["ptID"].isin(val_pts["ptID"])]

print(feas1_ecg_data_test["class_index"].value_counts())
print(feas1_ecg_data_val["class_index"].value_counts())

0    22274
2      377
1      102
Name: class_index, dtype: int64
0    22862
2      452
1      126
Name: class_index, dtype: int64


In [163]:
doc_path = r"C:\Users\daniel\Documents"

feas1_ecg_data_test.to_pickle(os.path.join(doc_path, "feas1_test_27_mar.pk"))
feas1_ecg_data_val.to_pickle(os.path.join(doc_path, "feas1_val_27_mar.pk"))

In [154]:
safer_ecg_data = pd.concat([feas2_ecg_data, feas1_ecg_data])
safer_pt_data = pd.concat([feas2_pt_data, feas1_pt_data])

safer_pt_data, safer_ecg_data = prepare_safer_data(safer_pt_data, safer_ecg_data)

In [9]:
safer_ecg_data.groupby("feas")["class_index"].value_counts()

feas  class_index
2     0              19513
      2                757
      1                 16
Name: class_index, dtype: int64

In [172]:
# Plot a heartrate histogram for AF and not AF
fig, ax = plt.subplots(figsize=(6, 4), dpi=300)
ax.hist(safer_ecg_data["heartrate"][(safer_ecg_data["measDiag"] != DiagEnum.AF) & (safer_ecg_data["feas"] == 1)], alpha=0.7, density=True, label="Normal or Other Rhythm")
ax.hist(safer_ecg_data["heartrate"][(safer_ecg_data["measDiag"] == DiagEnum.AF) & (safer_ecg_data["feas"] == 1)], alpha=0.7, density=True, label="AF")
ax.set_xlabel("Heartrate (bpm)")
ax.set_ylabel("Frequency proportion")
ax.legend()

fig.tight_layout()
fig.show()

In [16]:
# Cut out high and low heartrates - I dont think this makes a difference so havent been doing it mostly
safer_ecg_data = safer_ecg_data[(safer_ecg_data["heartrate"] < 120) & (safer_ecg_data["heartrate"] > 50)]

In [None]:
for _, ecg in safer_ecg_data[safer_ecg_data["feas"] == 1].sample(frac=1).iterrows():
    print(ecg[["measDiag", "class_index", "heartrate", "r_peaks"]])
    plot_ecg(ecg["data"], r_peaks=ecg["r_peaks"], fs=300, n_split=3)
    plt.show()

In [25]:
# Plot the 1 feas2 AF example with high heartrate!

plot_ecg(safer_ecg_data.loc[310209]["data"], r_peaks=safer_ecg_data.loc[310209]["r_peaks"], fs=300, n_split=3, figsize=(6, 5), export_quality=True)
plot_ecg_spectrogram(safer_ecg_data.loc[310209]["data"], fs=300, n_split=3, figsize=(6, 5), export_quality=True, cut_range=(2, 18))
plt.show()

### Load CinC 2020 data

In [9]:
import DataHandlers.CinC2020Dataset as CinC2020Dataset
import importlib
importlib.reload(CinC2020Dataset)

df = CinC2020Dataset.load_dataset(save_name="dataframe_2")

In [10]:
# At the moment we only select data with length which can be truncated to 3000 samples (10s)
def select_length(df):
    df_within_range = df[(df["length"] <= 5000) & (df["length"] >= 3000)].copy()
    df_within_range["data"] = df_within_range["data"].map(lambda x: x[:3000])
    df_within_range["length"] = df_within_range["data"].map(lambda x: x.shape[0])
    return df_within_range

df = select_length(df)

In [11]:
# Plot a heartrate histogram for AF and not AF
fig, ax = plt.subplots(figsize=(6, 4), dpi=300)
ax.hist(df["heartrate"][(df["measDiag"] != DiagEnum.AF)], alpha=0.7, density=True, label="Normal or Other Rhythm")
ax.hist(df["heartrate"][(df["measDiag"] == DiagEnum.AF)], alpha=0.7, density=True, label="AF")
ax.set_xlabel("Heartrate (bpm)")
ax.set_ylabel("Frequency proportion")
ax.legend()

fig.tight_layout()
fig.show()

In [26]:
noise = noise_df.sample()["data"].iloc[0] * np.random.normal(scale=1)

for _, ecg in df.iterrows():
    noise_scale = np.random.normal(scale=0.2)
    noise = noise_df.sample()["data"].iloc[0] * noise_scale
    print(noise_scale)
    plot_ecg(ecg["data"], figsize=(5, 2), export_quality=True)
    plot_ecg(ecg["data"] + noise, figsize=(5, 2), export_quality=True)
    plot_ecg(noise, figsize=(5, 2), export_quality=True)
    plt.show()

-0.04636585375203334
-0.12941265829123266


KeyboardInterrupt: 

In [83]:

for _, ecg in df[df["class_index"] == 0].iterrows():
    plot_ecg(ecg["data"][:1500], figsize=(5, 2.5), export_quality=True)
    plt.show()

KeyboardInterrupt: 

In [10]:
df.groupby("dataset")["class_index"].value_counts()

dataset               class_index
cpsc_2018             2               2047
                      0                984
                      1                903
cpsc_2018_extra       2                364
                      0                350
                      1                113
georgia               2               5257
                      0               3508
                      1                566
ptb-xl                0              10692
                      2               9349
                      1               1514
st_petersburg_incart  0              10955
                      1               1010
                      2                363
Name: class_index, dtype: int64

### Load noise from MIT database

In [16]:
import wfdb
import os
from scipy import signal

noises = ["em", "ma"]
noise_dfs = []
mit_dataset_path = "Datasets/mit-bih-noise-stress-test-database"

f_low = 0.67
f_high = 25

def split_signal(data, split_len):
    data_splits = []
    splits = np.arange(0, data["data"].shape[0], split_len)

    for i, (start, end) in enumerate(zip(splits, splits[1:])):
        data_split = data.copy()
        data_split["data"] = data["data"][start:end]
        data_split["data"] = (data_split["data"] - data_split["data"].mean())/ data_split["data"].std()

        data_split.name = i
        data_splits.append(data_split)

    return data_splits


for n_path in noises:
    rec = wfdb.rdrecord(os.path.join(mit_dataset_path, n_path))
    sig = np.concatenate([rec.p_signal[:, 0], rec.p_signal[:, 1]])

    bandpass = signal.butter(3, [f_low, f_high], 'bandpass', fs=rec.fs, output='sos')
    notch = signal.butter(3, [48, 52], 'bandstop', fs=rec.fs, output='sos')

    sig = filter_and_norm(sig, bandpass)
    sig = filter_and_norm(sig, notch)

    sig = resample(sig, rec.fs, 300)
    sig_series = pd.Series(data={"data": sig, "fs": 300, "noise_type": n_path})

    split_signals = split_signal(sig_series, 3000)
    split_signals = pd.DataFrame(split_signals)

    noise_dfs.append(split_signals)

noise_df = pd.concat(noise_dfs, ignore_index=True)

### Load CinC2017 Dataset

In [56]:
importlib.reload(CinCDataset)
import DataHandlers.DataProcessUtilities
importlib.reload(DataHandlers.DataProcessUtilities)
from DataHandlers.DataProcessUtilities import *

cinc2017_df = CinCDataset.load_cinc_dataset()

100%|██████████| 8528/8528 [00:05<00:00, 1609.09it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  ecg_data["rri_len"] = ecg_data["rri_feature"].map(lambda x: x[x > 0].shape[-1])


In [57]:
cinc2017_df = cinc2017_df[cinc2017_df["length"] == 9000]
cinc2017_df["measDiag"].value_counts()

DiagEnum.NoAF                      3694
DiagEnum.CannotExcludePathology    1649
DiagEnum.AF                         504
DiagEnum.PoorQuality                123
Name: measDiag, dtype: int64

In [58]:
cinc2017_df = cinc2017_df[cinc2017_df["length"] == 9000]
cinc2017_df = cinc2017_df[cinc2017_df["measDiag"] != DiagEnum.PoorQuality]

In [59]:
cinc2017_df["class_index"].value_counts()

0    5847
Name: class_index, dtype: int64

In [182]:
# Plot a heartrate histogram for AF and not AF
fig, ax = plt.subplots(figsize=(6, 4), dpi=300)
ax.hist(cinc2017_df["heartrate"][(cinc2017_df["class_index"] != 1)], alpha=0.7, density=True, label="Normal or Other Rhythm")
ax.hist(cinc2017_df["heartrate"][(cinc2017_df["class_index"] == 1)], alpha=0.7, density=True, label="AF")
ax.set_xlabel("Heartrate (bpm)")
ax.set_ylabel("Frequency proportion")
ax.legend()

fig.tight_layout()
fig.show()

In [81]:
ecgs = cinc2017_df[(cinc2017_df["class_index"] == 2) & (cinc2017_df["heartrate"] > 120)]

for _, ecg in ecgs.iterrows():
    plot_ecg(ecg["data"][:3000], 300, n_split=1, r_peaks=ecg["r_peaks"], figsize=(6, 2.5), export_quality=True)
    plot_ecg_spectrogram(ecg["data"][:3000], 300, n_split=1, cut_range=[2, 18], figsize=(6, 2.5), export_quality=True)
    plot_ecg_poincare(ecg["rri_feature"][:10], 10)# ecg["rri_len"])
    plt.show()

KeyboardInterrupt: 

### Generate dataloaders

In [38]:
mapper = CinC2020Dataset.CinC2020DiagMapper()
num_unique_classes = len(mapper.diag_desc.index)

# Note this only gets used for CinC data - the safer data labels were decided to have different meanings
def class_index_map(diag):
    if diag == DiagEnum.NoAF:
        return 0
    elif diag == DiagEnum.AF:
        return 1
    elif diag == DiagEnum.CannotExcludePathology:
        return 2
    elif diag == DiagEnum.Undecided:
        return 0

In [39]:
cinc2017_df["class_index"] = cinc2017_df["measDiag"].map(class_index_map)

NameError: name 'cinc2017_df' is not defined

In [12]:
# Onehot encoding
from torch.utils.data import Dataset, DataLoader

class Dataset(torch.utils.data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, dataset):
        'Initialization'
        self.dataset = dataset
        self.noise_prob = 0
        self.temp_warp = 0

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.dataset.index)

    def set_noise_prob(self, prob, power_std, noise_df):
        self.noise_prob = prob
        self.noise_power_std = power_std
        self.noise_df = noise_df

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        row = self.dataset.iloc[index]

        data = row["data"]
        rri = row["rri_feature"]
        rri_len = row["rri_len"]

        warp = np.random.binomial(1, self.temp_warp)
        if warp:
            data, r_peaks = DataAugmentations.temporal_warp(data, row["r_peaks_hamilton"])
            rri = get_rri_feature(r_peaks, 20)

        add_noise = np.random.binomial(1, self.noise_prob)
        if add_noise:
            noise = noise_df.sample()["data"].iloc[0] * np.random.normal(scale=self.noise_power_std)
            data += noise

        X = (data, rri, rri_len)
        y = row["class_index"]
        ind = row.name

        return X, y, ind

In [13]:
# For SAFER data
# Split train and test data according to each patient
# Note this function stratifies for AF and non AF!
def generate_patient_splits(pt_data, test_frac, val_frac):
    train_patients = []
    test_patients = []
    val_patients = []

    test_val_frac = test_frac + val_frac
    val_second_frac = val_frac/test_val_frac

    for val, df in pt_data.groupby("noAFRecs"):
        print(f"processing {val}")
        print(f"number of patients {len(df.index)}")



        n = math.floor(len(df.index) * test_val_frac)
        if  test_val_frac > 0:
            res = ((len(df.index) * test_val_frac) - n)/test_val_frac
        else:
            res = 0
        n += np.random.binomial(res, test_val_frac)
        test_val = df.sample(n)

        n = math.floor(len(test_val.index) * val_second_frac)
        if  val_second_frac > 0:
            res = ((len(test_val.index) * val_second_frac) - n)/val_second_frac
        else:
            res = 0
        n += np.random.binomial(res, val_second_frac)
        val = test_val.sample(n)
        val_patients.append(val)

        test_patients.append(test_val[~test_val["ptID"].isin(val["ptID"])])
        train_patients.append(df[~df["ptID"].isin(test_val["ptID"])])

    train_pt_df = pd.concat(train_patients)
    test_pt_df = pd.concat(test_patients)
    val_pt_df = pd.concat(val_patients)

    return train_pt_df, test_pt_df, val_pt_df


def make_SAFER_dataloaders(pt_data, ecg_data, test_frac, val_frac, batch_size=128):
    train_pt_df, test_pt_df, val_pt_df = generate_patient_splits(pt_data, test_frac, val_frac)

    print(f"Test AF: {test_pt_df['noAFRecs'].sum()} Normal: {test_pt_df['noNormalRecs'].sum()} Other: {test_pt_df['noOtherRecs'].sum()}")
    print(f"Train AF: {train_pt_df['noAFRecs'].sum()} Normal: {train_pt_df['noNormalRecs'].sum()} Other: {train_pt_df['noOtherRecs'].sum()}")
    print(f"Val AF: {val_pt_df['noAFRecs'].sum()} Normal: {val_pt_df['noNormalRecs'].sum()} Other: {val_pt_df['noOtherRecs'].sum()}")

    train_dataloader = None
    test_dataloader = None
    val_dataloader = None

    train_dataset = None
    test_dataset = None
    val_dataset = None

    if not train_pt_df.empty:
        # get ECG datasets
        train_dataset = ecg_data[ecg_data["ptID"].isin(train_pt_df["ptID"])]
        # Normalise
        train_dataset["data"] = (train_dataset["data"] - train_dataset["data"].map(lambda x: x.mean()))/train_dataset["data"].map(lambda x: x.std())
        torch_dataset_train = Dataset(train_dataset)
        train_dataloader = DataLoader(torch_dataset_train, batch_size=batch_size, shuffle=True, pin_memory=True)

    if not test_pt_df.empty:
        test_dataset = ecg_data[(ecg_data["ptID"].isin(test_pt_df["ptID"]))]
        test_dataset["data"] = (test_dataset["data"] - test_dataset["data"].map(lambda x: x.mean()))/test_dataset["data"].map(lambda x: x.std())
        torch_dataset_test = Dataset(test_dataset)
        test_dataloader = DataLoader(torch_dataset_test, batch_size=batch_size, shuffle=True, pin_memory=True)

    if not val_pt_df.empty:
        val_dataset = ecg_data[(ecg_data["ptID"].isin(val_pt_df["ptID"]))]
        val_dataset["data"] = (val_dataset["data"] - val_dataset["data"].map(lambda x: x.mean()))/val_dataset["data"].map(lambda x: x.std())
        torch_dataset_val = Dataset(val_dataset)
        val_dataloader = DataLoader(torch_dataset_val, batch_size=batch_size, shuffle=True, pin_memory=True)

    return train_dataloader, test_dataloader, val_dataloader, train_dataset, test_dataset, val_dataset

In [70]:
train_dataloader_safer, test_dataloader_safer, val_dataloader_safer, train_dataset_safer, test_dataset_safer, val_dataset_safer = make_SAFER_dataloaders(safer_pt_data, safer_ecg_data, test_frac=0.15, val_frac=0.15, batch_size=32)

processing 0.0
number of patients 2366
processing 1.0
number of patients 12
processing 2.0
number of patients 11
processing 3.0
number of patients 4
processing 4.0
number of patients 5
processing 5.0
number of patients 3
processing 6.0
number of patients 1
processing 8.0
number of patients 2
processing 9.0
number of patients 2
processing 10.0
number of patients 3
processing 11.0
number of patients 3
processing 18.0
number of patients 2
processing 19.0
number of patients 2
processing 22.0
number of patients 2
processing 26.0
number of patients 1
processing 29.0
number of patients 2
processing 35.0
number of patients 2
processing 39.0
number of patients 1
processing 45.0
number of patients 1
processing 53.0
number of patients 1
processing 62.0
number of patients 1
processing 80.0
number of patients 1
processing 94.0
number of patients 1
Test AF: 155.0 Normal: 24905.0 Other: 853.0
Train AF: 498.0 Normal: 118092.0 Other: 2360.0
Val AF: 176.0 Normal: 25902.0 Other: 713.0


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_dataset["data"] = (train_dataset["data"] - train_dataset["data"].map(lambda x: x.mean()))/train_dataset["data"].map(lambda x: x.std())


NameError: name 'Dataset' is not defined

In [14]:
def get_dataloaders(dataset, batch_size=32):
    torch_dataset = Dataset(dataset)
    dataloader = DataLoader(torch_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
    return dataloader

In [15]:
# validate on Feas2 and train/test on feas1
val_dataset_safer = safer_ecg_data[safer_ecg_data["feas"] == 2]
val_dataloader_safer = get_dataloaders(val_dataset_safer)

In [16]:
val_dataset_safer["class_index"].value_counts()

0    19513
2      757
1       16
Name: class_index, dtype: int64

In [17]:
### Make dataloaders for CinC data - separate cpsc as the validation set
from sklearn.model_selection import train_test_split

val_dataset = df[df["dataset"] == "cpsc_2018"]
train_dataset, test_dataset = train_test_split(df[df["dataset"] != "cpsc_2018"], test_size=0.15, stratify=df[df["dataset"] != "cpsc_2018"]["class_index"])
# test_dataset, val_dataset = train_test_split(test_dataset, test_size=0.5, stratify=test_dataset["class_index"])

test_dataset = test_dataset[test_dataset["measDiag"] != DiagEnum.Undecided]  # Should just remove any errors in loading the dataset
val_dataset = val_dataset[val_dataset["measDiag"] != DiagEnum.Undecided]  # Should just remove any errors in loading the dataset

torch_dataset_test = Dataset(test_dataset)
test_dataloader = DataLoader(torch_dataset_test, batch_size=128, shuffle=True, pin_memory=True)

torch_dataset_val = Dataset(val_dataset)
val_dataloader = DataLoader(torch_dataset_val, batch_size=128, shuffle=True, pin_memory=True)

torch_dataset_train = Dataset(train_dataset)
# torch_dataset_train.temp_warp = 0.2
# torch_dataset_train.set_noise_prob(0.1, 0.2, noise_df)
train_dataloader = DataLoader(torch_dataset_train, batch_size=128, shuffle=True, pin_memory=True)

In [90]:
# Set the proportion of AF samples in the test data to that of the train data

val_df_counts = val_dataset["class_index"].value_counts()
train_df_counts = train_dataset["class_index"].value_counts()

train_not_af = train_df_counts.loc[2] + train_df_counts.loc[0]
val_not_af = val_df_counts.loc[2] + val_df_counts.loc[0]

val_af_wanted = int(round((train_df_counts.loc[1]/train_not_af) * val_not_af))

wanted_af_samples = val_dataset[val_dataset["class_index"] == 1].sample(val_af_wanted)
val_dataset = pd.concat([val_dataset[val_dataset["class_index"] != 1], wanted_af_samples])

torch_dataset_val = Dataset(val_dataset)
val_dataloader = DataLoader(torch_dataset_val, batch_size=32, shuffle=True, pin_memory=True)

In [63]:
### CinC2017 data
from sklearn.model_selection import train_test_split

test_size = 0.15
val_size = 0.15

train_dataset_2017, test_val = train_test_split(cinc2017_df.dropna(subset="class_index"), test_size=test_size + val_size, stratify=cinc2017_df["class_index"].dropna())
test_dataset_2017, val_dataset_2017 = train_test_split(test_val, test_size=val_size/(test_size + val_size), stratify=test_val["class_index"])

# test_dataset_2017 = test_dataset_2017[test_dataset_2017["measDiag"] != DiagEnum.Undecided]  # Should just remove any errors in loading the dataset

torch_dataset_test = Dataset(test_dataset_2017)
test_dataloader_2017 = DataLoader(torch_dataset_test, batch_size=32, shuffle=True, pin_memory=True)

torch_dataset_train = Dataset(train_dataset_2017)
train_dataloader_2017 = DataLoader(torch_dataset_train, batch_size=32, shuffle=True, pin_memory=True)

torch_dataset_val = Dataset(val_dataset_2017)
val_dataloader_2017 = DataLoader(torch_dataset_val, batch_size=32, shuffle=True, pin_memory=True)

In [64]:
print(train_dataset_2017["class_index"].value_counts())
print(test_dataset_2017["class_index"].value_counts())
print(val_dataset_2017["class_index"].value_counts())

0    2585
2    1154
1     353
Name: class_index, dtype: int64
0    554
2    247
1     76
Name: class_index, dtype: int64
0    555
2    248
1     75
Name: class_index, dtype: int64


In [184]:
# Save the CinC2017 data splits for consistent results!
train_dataset_2017.to_pickle("TrainedModels/19_May_cinc_2017_train.pk")
test_dataset_2017.to_pickle("TrainedModels/19_May_cinc_2017_test.pk")
val_dataset_2017.to_pickle("TrainedModels/19_May_cinc_2017_val.pk")

In [18]:
train_dataset_2017 = pd.read_pickle("TrainedModels/19_May_cinc_2017_train.pk")
test_dataset_2017 = pd.read_pickle("TrainedModels/19_May_cinc_2017_test.pk")
val_dataset_2017 = pd.read_pickle("TrainedModels/19_May_cinc_2017_val.pk")

train_dataloader_2017 = get_dataloaders(train_dataset_2017, 32)
test_dataloader_2017 = get_dataloaders(test_dataset_2017, 32)
val_dataloader_2017 = get_dataloaders(val_dataset_2017, 32)

In [131]:
### Use whole CinC2017 as a test
dataset_2017 = cinc2017_df[cinc2017_df["measDiag"] != DiagEnum.Undecided].dropna(subset="class_index")

torch_dataset = Dataset(dataset_2017)
dataloader_2017 = DataLoader(torch_dataset, batch_size=32, shuffle=True, pin_memory=True)

In [132]:
dataset_2017["class_index"].value_counts()

0    3694
2    1649
1     504
Name: class_index, dtype: int64

### Use the noise detector to filter the datasets

In [182]:
# Filter noisy things out of SAFER
import Models.NoiseCNN
import importlib

importlib.reload(Models.NoiseCNN)
from Models.NoiseCNN import CNN, hyperparameters

noiseDetector = CNN(**hyperparameters).to(device)
noiseDetector.load_state_dict(torch.load("TrainedModels/CNN_16_may_final_no_undecided.pt", map_location=device))
noiseDetector.eval()

def add_noise_predictions(nd, dataloader, dataset):
    noise_ps = []
    inds = []

    with torch.no_grad():
        for i, (signals, labels, ind) in enumerate(dataloader):
            signal = signals[0].to(device).float()
            noise_prob = nd(torch.unsqueeze(signal, 1)).detach().to("cpu").numpy()

            for i, n in zip(ind, noise_prob):
                if type(i) == str:
                    inds.append(i)
                else:
                    inds.append(i.item())
                noise_ps.append(float(n))

    if dataset is not None:
        dataset["noise_probs"] = pd.Series(data=noise_ps, index=inds)
    else:
        return pd.Series(data=noise_ps, index=inds)

In [20]:
add_noise_predictions(noiseDetector, val_dataloader_safer, val_dataset_safer)
# add_noise_predictions(noiseDetector, test_dataloader_safer, test_dataset_safer)
# add_noise_predictions(noiseDetector, train_dataloader_safer, train_dataset_safer)

In [21]:
# Remove the noisy samples
# train_dataset_safer_clean = train_dataset_safer[train_dataset_safer["noise_probs"] < 0]
# test_dataset_safer_clean = test_dataset_safer[test_dataset_safer["noise_probs"] < 0]
val_dataset_safer_clean = val_dataset_safer[val_dataset_safer["noise_probs"] < 0]

# print(len(train_dataset_safer_clean.index))
# print(len(test_dataset_safer_clean.index))
print(len(val_dataset_safer_clean.index))

# train_dataloader_safer_clean = get_dataloaders(train_dataset_safer_clean)
# test_dataloader_safer_clean = get_dataloaders(test_dataset_safer_clean)
val_dataloader_safer_clean = get_dataloaders(val_dataset_safer_clean)

17861


In [185]:
add_noise_predictions(noiseDetector, train_dataloader_2017, train_dataset_2017)
add_noise_predictions(noiseDetector, test_dataloader_2017, test_dataset_2017)
add_noise_predictions(noiseDetector, val_dataloader_2017, val_dataset_2017)

In [186]:
print(train_dataset_2017["class_index"].value_counts())
print(test_dataset_2017["class_index"].value_counts())
print(val_dataset_2017["class_index"].value_counts())

0    2585
2    1154
1     353
Name: class_index, dtype: int64
0    554
2    247
1     76
Name: class_index, dtype: int64
0    555
2    248
1     75
Name: class_index, dtype: int64


In [187]:
# Remove the noisy samples
thresh = 0.3

train_dataset_2017_clean = train_dataset_2017[train_dataset_2017["noise_probs"] < 0]
test_dataset_2017_clean = test_dataset_2017[test_dataset_2017["noise_probs"] < 0]
val_dataset_2017_clean = val_dataset_2017[val_dataset_2017["noise_probs"] < 0]

print(train_dataset_2017_clean["class_index"].value_counts())
print(test_dataset_2017_clean["class_index"].value_counts())
print(val_dataset_2017_clean["class_index"].value_counts())


train_dataloader_2017_clean = get_dataloaders(train_dataset_2017_clean)
test_dataloader_2017_clean = get_dataloaders(test_dataset_2017_clean)
val_dataset_2017_clean = val_dataset_2017[val_dataset_2017["noise_probs"] < 0]

0    1955
2     857
1     255
Name: class_index, dtype: int64
0    405
2    172
1     50
Name: class_index, dtype: int64
0    392
2    178
1     50
Name: class_index, dtype: int64


In [130]:
import threading

class DatasetSequenceIterator:

    def __init__(self, data_loading_functions, batch_sizes, filter=lambda x:x):
        self.dl_functions = data_loading_functions

        self.dataset = None
        self.next_dataset = None

        self.dataloader_iterator = None
        self.next_dataloader_iterator = None

        self.next_dataset_loaded = False
        self.dataloader_thread = None

        self.filter = filter

        self.batch_sizes = batch_sizes
        self.dl_index = 0

    def __iter__(self):
        self.dl_index = -1
        self.dataloader_thread = threading.Thread(target=self.load_next_dataset)
        self.dataloader_thread.start()
        self.dataloader_thread.join()
        self.swap_to_next_dataset()
        self.dl_index += 1
        self.dataloader_thread = threading.Thread(target=self.load_next_dataset)
        self.dataloader_thread.start()
        print(self.dl_index)
        return self

    def __len__(self):
        # TODO make this return the right value
        return 100

    def swap_to_next_dataset(self):
        self.dataset = self.next_dataset
        self.dataloader_iterator = self.next_dataloader_iterator
        self.next_dataset_loaded = False

    def load_next_dataset(self):
        if self.dl_index + 1 < len(self.dl_functions):
            print(f"Loading dataset {self.dl_index + 1}")
            self.next_dataset = self.dl_functions[self.dl_index + 1]()
            self.next_dataset = self.filter(self.next_dataset)

            torch_dataset = Dataset(self.next_dataset)
            self.next_dataloader_iterator = iter(DataLoader(torch_dataset, batch_size=self.batch_sizes[self.dl_index], shuffle=True, pin_memory=True))
            self.next_dataset_loaded = True
        else:
            print("Finished loading all datasets")
            self.next_dataset_loaded = False
            return None



    def __next__(self):
        try:
            ret = next(self.dataloader_iterator)
        except StopIteration:
            print("stop_iteration")
            if self.dl_index >= len(self.dl_functions):
                # We have gone through all the datasets
                print("Completed all datasets")
                raise StopIteration
            else:

                if not self.next_dataset_loaded:
                    print("waiting_for_next_dataset")
                    self.dataloader_thread.join()

                self.swap_to_next_dataset()
                self.dl_index += 1
                self.dataloader_thread = threading.Thread(target=self.load_next_dataset)
                self.dataloader_thread.start()
                ret = next(self.dataloader_iterator)

        return ret

In [180]:
# Testing the DatasetSequenceIterator by dividing feas1 into two parts

def load_feas1_first_half():
    ecg_data, pt_data = load_feas1_chunk_range((0, 1))
    return prepare_safer_data(pt_data, ecg_data)[1]

def load_feas1_second_half():
    ecg_data, pt_data = load_feas1_chunk_range((4, 5))
    return prepare_safer_data(pt_data, ecg_data)[1]

def load_feas1_nth_chuck(n):
    ecg_data, pt_data = load_feas1_chunk_range((n, n+1))
    return prepare_safer_data(pt_data, ecg_data)[1]

loading_functions = [lambda n=n: load_feas1_nth_chuck(n) for n in range(num_chunks)]

feas1_dataloader_entire = DatasetSequenceIterator(loading_functions, [128 for n in range(num_chunks)])

In [None]:
num_ecgs = 0
for i, (signals, labels, _) in enumerate(feas1_dataloader):
    signal = signals[0].to(device).float()
    rris = signals[1].to(device).float()
    rri_len = signals[2].to(device).float()

    num_ecgs += signal.shape[0]

print(num_ecgs)

In [None]:
feas1_noise_predictions = add_noise_predictions(noiseDetector, feas1_dataloader_entire, None)

In [None]:
print(f"number of noisy ECGs: {(feas1_noise_predictions > 0).sum()}")
feas1_path = r"D:\2022_23_DSiromani\Feas1"
feas1_noise_predictions.to_pickle(os.path.join(feas1_path, "ECGs/feas1_noise_predictions.pk"))

In [175]:
feas1_path = r"D:\2022_23_DSiromani\Feas1"
feas1_noise_predictions = pd.read_pickle(os.path.join(feas1_path, "ECGs/feas1_noise_predictions.pk"))

In [176]:
pt_data = SAFERDataset.load_pt_dataset(1)
ecg_data = SAFERDataset.load_ecg_csv(1, pt_data, ecg_range=None, ecg_meas_diag=None, feas2_offset=10000, feas2_ecg_offset=200000)

ecg_data["feas"] = 1
ecg_data["length"] = 9120
ecg_data["rri_len"] = 20

pt_data, ecg_data = prepare_safer_data(pt_data, ecg_data)

# train_pts, test_pts, val_pts = generate_patient_splits(pt_data, 0.15, 0.15)

In [145]:
zenicor_conf_mat = confusion_matrix(feas1_ecg_data_test["class_index"], feas1_ecg_data_test["poss_AF_tag"])
print_results(zenicor_conf_mat)

Confusion matrix:
[[18614   496     0]
 [   21    77     0]
 [  260   206     0]]
Sensitivity: 0.786
Specificity: 0.964

Normal F1: 0.980
AF F1: 0.176
Other F1: 0.000


In [189]:
ecg_data["noise_prediction"] = feas1_noise_predictions
print(ecg_data[ecg_data["noise_prediction"] < 0]["class_index"].value_counts())
print(ecg_data["class_index"].value_counts())

0    125865
2      2554
1       529
Name: class_index, dtype: int64
0    149586
2      3120
1       745
Name: class_index, dtype: int64


In [177]:
for pts in [train_pts, test_pts, val_pts]:
    print(pts["noNormalRecs"].sum())
    print(pts["noAFRecs"].sum())
    print(pts["noOtherRecs"].sum())
    print("")

104364.0
515.0
2282.0

22294.0
102.0
377.0

22928.0
128.0
461.0



In [163]:
feas1_path = r"D:\2022_23_DSiromani\Feas1"
train_pts.to_pickle(os.path.join(feas1_path, "all_feas1_train_pts.pk"))
test_pts.to_pickle(os.path.join(feas1_path, "all_feas1_test_pts.pk"))
val_pts.to_pickle(os.path.join(feas1_path, "all_feas1_val_pts.pk"))

In [133]:
feas1_path = r"D:\2022_23_DSiromani\Feas1"
train_pts = pd.read_pickle(os.path.join(feas1_path, "ECGs/all_feas1_train_pts.pk"))
test_pts = pd.read_pickle(os.path.join(feas1_path, "ECGs/all_feas1_test_pts.pk"))
val_pts = pd.read_pickle(os.path.join(feas1_path, "ECGs/all_feas1_val_pts.pk"))

In [137]:
feas1_ecg_data_test = pd.read_pickle(os.path.join(feas1_path, "ECGs/feas1_test_26_mar.pk"))
feas1_ecg_data_test = feas1_ecg_data_test[feas1_ecg_data_test["rri_len"] > 5]

feas1_ecg_data_val = pd.read_pickle(os.path.join(feas1_path, "ECGs/feas1_val_26_mar.pk"))
feas1_ecg_data_val = feas1_ecg_data_val[feas1_ecg_data_val["rri_len"] > 5]

In [145]:
feas1_ecg_data_test["noise_prediction"] = feas1_noise_predictions
feas1_ecg_data_test_clean = feas1_ecg_data_test[feas1_ecg_data_test["noise_prediction"] < 0]

feas1_ecg_data_val["noise_prediction"] = feas1_noise_predictions
feas1_ecg_data_val_clean = feas1_ecg_data_val[feas1_ecg_data_val["noise_prediction"] < 0]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  feas1_ecg_data_test["noise_prediction"] = feas1_noise_predictions
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  feas1_ecg_data_val["noise_prediction"] = feas1_noise_predictions


In [149]:
feas1_ecg_data_val_clean["class_index"].value_counts()

0    18534
2      364
1       80
Name: class_index, dtype: int64

In [179]:
ecg_data[(feas1_noise_predictions[ecg_data.index] < 0)]

IndexingError: Unalignable boolean Series provided as indexer (index of the boolean Series and of the indexed object do not match).

In [150]:
# Create some a filter function to select data from each partition
def filter_train_pts(ecg_data):
    print(f"filtering {(feas1_noise_predictions[ecg_data.index] > 0).sum()} ECGs out")
    return ecg_data[(ecg_data["ptID"].isin(train_pts["ptID"])) & (feas1_noise_predictions[ecg_data.index] < 0)]

def filter_test_pts(ecg_data):
    return ecg_data[ecg_data["ptID"].isin(test_pts["ptID"]) & (feas1_noise_predictions[ecg_data.index] < 0)]

def filter_val_pts(ecg_data):
    return ecg_data[ecg_data["ptID"].isin(test_pts["ptID"]) & (feas1_noise_predictions[ecg_data.index] < 0)]

feas1_train_dataloader = DatasetSequenceIterator(loading_functions, [64 for n in range(num_chunks)], filter=filter_train_pts)
feas1_test_dataloader = get_dataloaders(feas1_ecg_data_test)
feas1_val_dataloader = get_dataloaders(feas1_ecg_data_val)

In [33]:
feas1_test_dataloader_clean = get_dataloaders(feas1_ecg_data_test_clean)
feas1_val_dataloader_clean = get_dataloaders(feas1_ecg_data_val_clean)

NameError: name 'feas1_ecg_data_test_clean' is not defined

In [223]:
print(feas1_ecg_data_test_clean["class_index"].value_counts())
print(feas1_ecg_data_val_clean["class_index"].value_counts())

print("  ")

print(feas1_ecg_data_test["class_index"].value_counts())
print(feas1_ecg_data_val["class_index"].value_counts())

0    19247
2      325
1       60
Name: class_index, dtype: int64
0    19151
2      391
1       90
Name: class_index, dtype: int64
  
0    22274
2      377
1      102
Name: class_index, dtype: int64
0    22862
2      452
1      126
Name: class_index, dtype: int64


### Prepare for training

In [138]:
del model
del feas1_train_dataloader
del feas1_test_dataloader
torch.cuda.empty_cache()

import gc
gc.collect()

5522

In [279]:
import Models.SpectrogramTransformer
importlib.reload(Models.SpectrogramTransformer)
# from Models.SpectrogramTransformer import TransformerModel
import Models.SpectrogramTransformerAttentionPooling
importlib.reload(Models.SpectrogramTransformerAttentionPooling)
from Models.SpectrogramTransformerAttentionPooling import TransformerModel

In [197]:
from torch.optim.lr_scheduler import StepLR, LambdaLR, SequentialLR

In [283]:
n_head = 4
n_fft = 128
embed_dim = 128 # int(n_fft/2)
n_inp_rri = 64

model = TransformerModel(3, embed_dim, n_head, 512, 6, n_fft, n_inp_rri, device=device).to(device)

(2, 18)


In [281]:
class focal_loss(nn.Module):

    def __init__(self, weights, gamma=2, label_smoothing=0):
        super(focal_loss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss(reduction="none", label_smoothing=label_smoothing)
        self.weights = weights
        self.gamma = gamma

    def forward(self, pred, targets):
        ce = self.ce_loss(pred, targets)
        pt = torch.exp(-ce)

        loss_sum = torch.sum(((1-pt) ** self.gamma) * ce * self.weights[targets])
        norm_factor = torch.sum(self.weights[targets])
        return loss_sum/norm_factor

In [None]:
class_counts = torch.tensor(train_dataset["class_index"].value_counts().sort_index().values.astype(np.float32))
class_weights = (1/class_counts)
class_weights /= torch.sum(class_weights)
print(class_weights)

loss_func = focal_loss(class_weights, 2) # nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1) # focal_loss(class_weights, 2, 0.05) #
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

number_warmup_batches = 600
def warmup(current_step: int):
    if current_step < number_warmup_batches:
        # print(current_step / number_warmup_batches ** 1.5)
        return current_step / number_warmup_batches ** 1.5
    else:
        # print(1/math.sqrt(current_step))
        return 1/math.sqrt(current_step)  # 1 / (10 ** (float(number_warmup_epochs - current_step)))

scheduler = LambdaLR(optimizer, lr_lambda=warmup)
# scheduler = SequentialLR(optimizer, [warmup_scheduler, scheduler], [number_warmup_epochs])

In [190]:
# Remake scheduler before retraining on SAFER

"""
class_counts = torch.tensor(train_dataset_safer_clean["class_index"].value_counts().sort_index().values.astype(np.float32))
"""

"""
# just approximate weights using feas2 rather than computing for feas 1 - these might be fundamentally different because in feas2 the cardiologist stopped labelling after the first AF from a patient therefore fewer AF.
class_counts = torch.tensor(val_dataset_safer["class_index"].value_counts().sort_index().values.astype(np.float32))
"""

"""
# Use all of feas1 to compute the class counts - precomputed values for next time: tensor([0.0043, 0.7924, 0.2033])
class_counts = torch.tensor(feas1_ecg_data["class_index"].value_counts().sort_index().values.astype(np.float32))

class_weights = (1/class_counts)
class_weights /= torch.sum(class_weights)
"""

class_counts = torch.tensor(ecg_data[ecg_data["noise_prediction"] < 0]["class_index"].value_counts().sort_index().values.astype(np.float32))

class_weights = (1/class_counts)
class_weights /= torch.sum(class_weights)

# class_weights = torch.tensor([0.0043, 0.7924, 0.2033])

print(class_weights)

loss_func = focal_loss(class_weights, gamma=2, label_smoothing=0)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)

number_warmup_batches = 600
def warmup(current_step: int):
    if current_step < number_warmup_batches:
        # print(current_step / number_warmup_batches ** 1.5)
        return current_step / number_warmup_batches ** 1.5
    else:
        # print(1/math.sqrt(current_step))
        return 1/math.sqrt(current_step)  # 1 / (10 ** (float(number_warmup_epochs - current_step)))

scheduler = LambdaLR(optimizer, lr_lambda=warmup)

tensor([0.0035, 0.8255, 0.1710])


In [175]:
# Remake scheduler before retraining on CinC2017

class_counts = torch.tensor(train_dataset_2017["class_index"].value_counts().sort_index().values.astype(np.float32))
class_weights = (1/class_counts)
class_weights /= torch.sum(class_weights)
print(class_weights)

loss_func = nn.CrossEntropyLoss(weight=class_weights) # multiclass_cross_entropy_loss

optimizer = torch.optim.Adam(model.parameters(), lr=0.00004)

number_warmup_batches = 600
def warmup(current_step: int):
    if current_step < number_warmup_batches:
        # print(current_step / number_warmup_batches ** 1.5)
        return current_step / number_warmup_batches ** 1.5
    else:
        # print(1/math.sqrt(current_step))
        return 1/math.sqrt(current_step)  # 1 / (10 ** (float(number_warmup_epochs - current_step)))

scheduler = LambdaLR(optimizer, lr_lambda=warmup)

tensor([0.0947, 0.6933, 0.2121])


In [107]:
# Train the model I stole

import OtherModels.Prna.physionet2020_submission.model
importlib.reload(OtherModels.Prna.physionet2020_submission.model)
from OtherModels.Prna.physionet2020_submission.model import CTN
import OtherModels.Prna.physionet2020_submission.optimizer
importlib.reload(OtherModels.Prna.physionet2020_submission.optimizer)
from OtherModels.Prna.physionet2020_submission.optimizer import NoamOpt

# Train prna's transformer
n_head = 8
n_fft = 128
embed_dim = 128 # int(n_fft/2)
n_inp_rri = 64

class_counts = torch.tensor(train_dataset["class_index"].value_counts().sort_index().values.astype(np.float32))
class_weights = (1/class_counts)
class_weights /= torch.sum(class_weights)
print(class_weights)

model = CTN(256, n_head, 2048, 4, 0.1, 64, 0, 0, 3).to(device)

# Initialize parameters with Glorot / fan_avg.
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# optimizer = NoamOpt(256, 1, 4000, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
loss_func = nn.CrossEntropyLoss(weight=class_weights)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

number_warmup_batches = 2
def warmup(current_step: int):
    return 1 / (10 ** (float(number_warmup_batches - current_step)))
warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup)

scheduler = SequentialLR(optimizer, [warmup_scheduler, scheduler], [number_warmup_batches])

tensor([0.1427, 0.7559, 0.1014])


In [181]:
from torch.profiler import profile, tensorboard_trace_handler
from tqdm import tqdm

import copy
model = model.to(device)
model.fix_transformer_params(fix_spec=False, fix_rri=False)
num_epochs = 40

def train(model, train_dataloader, test_dataloader):
    best_test_loss = 100
    best_epoch = -1
    best_model = copy.deepcopy(model).cpu()

    losses = []


    for epoch in range(num_epochs):
        total_loss = 0
        print(f"starting epoch {epoch} ...")
        # Train
        num_batches = 0
        model.train()
        for i, (signals, labels, _) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            signal = signals[0].to(device).float()
            rris = signals[1].to(device).float()
            rri_len = signals[2].to(device).float()

            if torch.any(torch.isnan(signal)):
                print("Signals are nan")
                continue

            if torch.any(torch.isnan(rris)):
                print("Signals are nan")
                continue

            labels = labels.long()
            optimizer.zero_grad()
            output = model(signal, rris, rri_len).to("cpu")

            if torch.any(torch.isnan(output)):
                print(signal)
                print(rris)
                print(rri_len)
                print(output)
                raise ValueError

            loss = loss_func(output, labels)
            if torch.isnan(loss):
                raise ValueError
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2)
            optimizer.step()
            scheduler.step()
            num_batches += 1
            total_loss += float(loss)

        print(num_batches)

        print(f"Epoch {epoch} finished with average loss {total_loss/num_batches}")
        # writer.add_scalar("Loss/train", total_loss/num_batches, epoch)
        print("Testing ...")
        # Test
        num_test_batches = 0
        test_loss = 0
        with torch.no_grad():
            model.eval()
            for i, (signals, labels, _) in enumerate(test_dataloader):
                signal = signals[0].to(device).float()
                rris = signals[1].to(device).float()
                rri_len = signals[2].to(device).float()

                if torch.any(torch.isnan(signal)):
                    print("Signals are nan")
                    continue

                labels = labels.long()
                output = model(signal, rris, rri_len).to("cpu")
                loss = loss_func(output, labels)
                test_loss += float(loss)
                num_test_batches += 1

        print(f"Average test loss: {test_loss/num_test_batches}")
        losses.append([total_loss/num_batches, test_loss/num_test_batches])
        # writer.add_scalar("Loss/test", test_loss/num_t est_batches, epoch)

        if test_loss/num_test_batches < best_test_loss:
            best_model = copy.deepcopy(model).cpu()
            best_test_loss = test_loss/num_test_batches
            best_epoch = epoch
        else:
            if best_epoch + 5 <= epoch:
                return best_model, losses

    return best_model, losses

model, losses = train(model, val_dataloader_safer_clean, val_dataset_safer_clean)
model = model.to(device)

In [276]:
losses = np.load("TrainedModels/Transformer_23_May_cinc_train_attention_pooling_no_augmentation_smoothing_training_curve.npy")

# "C:\Users\daniel\Documents\CambridgeSoftwareProjects\ecg-signal-quality\TrainedModels\Transformer_23_May_feas1_train_attention_pooling_augmentation_smoothing_retrain.npy"

In [277]:
# plot the training curve (1 axis only)
fig, ax = plt.subplots(figsize=(6, 4))

train_l = ax.plot([l[0] for l in losses], label="training loss")
val_l = ax.plot([l[1] for l in losses], label="validation loss", color="#ff7f0e")

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# ax.set_ylim(bottom=0)
ax.set_xlim(left=0)

ax.set_xlabel("Epoch number")
ax.set_ylabel("Loss")

ax.legend()

fig.tight_layout()
plt.show()

In [181]:
losses_np = np.array(losses)
np.save("TrainedModels/Transformer_24_May_cinc_2017_train_attention_pooling_augmentation_smoothing", losses_np)

In [179]:
# Save a model
torch.save(model.state_dict(), "TrainedModels/Transformer_24_May_cinc_2017_train_attention_pooling_augmentation_smoothing.pt")

# train_dataset_safer.to_pickle("TrainedModels/Transformer_15_Mar_train.pk")
# test_dataset_safer.to_pickle("TrainedModels/Transformer_15_Mar_test.pk")
# val_dataset_safer.to_pickle("TrainedModels/Transformer_15_Mar_val.pk")
# train_pt_df.to_pickle("TrainedModels/Transformer_spectrogram_small_fft_cut_all_safer_trained_average_warped_train.pk")
# val_pt_df.to_pickle("TrainedModels/Transformer_spectrogram_small_fft_cut_all_safer_trained_average_warped_val.pk")
# test_pt_df.to_pickle("TrainedModels/Transformer_spectrogram_small_fft_cut_all_safer_trained_average_warped_test.pk")

In [122]:
train_dataset.to_pickle("TrainedModels/Transformer_13_May_cinc_trained_initial_train.pk")
test_dataset.to_pickle("TrainedModels/Transformer_13_May_cinc_trained_initial_test.pk")
val_dataset.to_pickle("TrainedModels/Transformer_13_May_cinc_trained_initial_val.pk")

In [5]:
train_dataset_safer = pd.read_pickle("TrainedModels/Transformer_13_Mar_train.pk")
test_dataset_safer = pd.read_pickle("TrainedModels/Transformer_13_Mar_test.pk")
val_dataset_safer = pd.read_pickle("TrainedModels/Transformer_13_Mar_val.pk")

In [13]:
train_dataloader_safer = get_dataloaders(train_dataset_safer)
test_dataloader_safer = get_dataloaders(test_dataset_safer)
val_dataloader_safer = get_dataloaders(val_dataset_safer)

In [144]:
train_dataset_2017.to_pickle("TrainedModels/Transformer_13_May_cinc_2017_trained_train.pk")
test_dataset_2017.to_pickle("TrainedModels/Transformer_13_May_cinc_2017_trained_test.pk")

In [34]:
# Set this for safer cross validation later
cinc_model_path = "TrainedModels/Transformer_15_Mar_cinc_trained_noise_augmentation.pt"

In [284]:
# Load a model
# model = TransformerModel(2, embed_dim, n_head, 1024, 4, 47, n_fft).to(device)
model.load_state_dict(torch.load("TrainedModels/Transformer_20_May_cinc_train_attention_pooling_augmentation_smoothing.pt", map_location=device))

# train_pt_df = pd.read_pickle("TrainedModels/Transformer_spectrogram_small_fft_cut_all_safer_trained_average_warped_train.pk")
# val_pt_df = pd.read_pickle("TrainedModels/Transformer_spectrogram_small_fft_cut_all_safer_trained_average_warped_val.pk")
# test_pt_df = pd.read_pickle("TrainedModels/Transformer_spectrogram_small_fft_cut_all_safer_trained_average_warped_test.pk")


<All keys matched successfully>

In [None]:
# Load dataset
train_pt_df = pd.read_pickle("TrainedModels/Transformer_13_May_cinc_2017_trained_train.pk")

In [None]:
# Tonights training schedule

print("Stage 0: Training final model with noisy samples removed")
try:
    raise Exception
    feas1_noise_predictions = add_noise_predictions(noiseDetector, feas1_dataloader_entire, None)

    print(f"number of noisy ECGs: {(feas1_noise_predictions > 0).sum()}")
    feas1_path = r"D:\2022_23_DSiromani\Feas1"
    feas1_noise_predictions.to_pickle(os.path.join(feas1_path, "ECGs/feas1_noise_predictions.pk"))

    ecg_data["noise_prediction"] = feas1_noise_predictions
    print(ecg_data[ecg_data["noise_prediction"] < 0]["class_index"].value_counts())
    print(ecg_data["class_index"].value_counts())

    feas1_ecg_data_test["noise_prediction"] = feas1_noise_predictions
    feas1_ecg_data_test_clean = feas1_ecg_data_test[feas1_ecg_data_test["noise_prediction"] < 0]

    feas1_ecg_data_val["noise_prediction"] = feas1_noise_predictions
    feas1_ecg_data_val_clean = feas1_ecg_data_val[feas1_ecg_data_val["noise_prediction"] < 0]

    # Create some a filter function to select data from each partition
    def filter_train_pts(ecg_data):
        print(f"filtering {(feas1_noise_predictions[ecg_data.index] > 0).sum()} ECGs out")
        return ecg_data[(ecg_data["ptID"].isin(train_pts["ptID"])) & (feas1_noise_predictions[ecg_data.index] < 0)]

    def filter_test_pts(ecg_data):
        return ecg_data[ecg_data["ptID"].isin(test_pts["ptID"]) & (feas1_noise_predictions[ecg_data.index] < 0)]

    def filter_val_pts(ecg_data):
        return ecg_data[ecg_data["ptID"].isin(test_pts["ptID"]) & (feas1_noise_predictions[ecg_data.index] < 0)]

    feas1_train_dataloader = DatasetSequenceIterator(loading_functions, [64 for n in range(num_chunks)], filter=filter_train_pts)
    feas1_test_dataloader = get_dataloaders(feas1_ecg_data_test)
    feas1_val_dataloader = get_dataloaders(feas1_ecg_data_val)

    feas1_test_dataloader_clean = get_dataloaders(feas1_ecg_data_test_clean)
    feas1_val_dataloader_clean = get_dataloaders(feas1_ecg_data_val_clean)

except Exception as e:
    print("Error occured in stage 0")
    print(e)

print("Stage 1: Training final model with noisy samples removed")
try:
    raise Exception
    model.load_state_dict(torch.load("TrainedModels/Transformer_20_May_cinc_train_attention_pooling_augmentation_smoothing.pt", map_location=device))
    model, losses = train(model, feas1_train_dataloader, feas1_test_dataloader_clean)

    losses_np = np.array(losses)
    np.save("TrainedModels/Transformer_27_May_feas1_train_attention_pooling_augmentation_smoothing_training_curve_no_noisy_nk_beats", losses_np)

    torch.save(model.state_dict(), "TrainedModels/Transformer_27_May_feas1_train_attention_pooling_augmentation_smoothing_no_noisy_nk_beats.pt")
except Exception as e:
    print("Error occured in stage 1")
    print(e)

print("Stage 2: Training final model with noisy samples included")
try:
    import Models.SpectrogramTransformerAttentionPooling
    importlib.reload(Models.SpectrogramTransformerAttentionPooling)
    from Models.SpectrogramTransformerAttentionPooling import TransformerModel

    model = TransformerModel(3, embed_dim, n_head, 512, 6, n_fft, n_inp_rri, device=device).to(device)
    model.load_state_dict(torch.load("TrainedModels/Transformer_20_May_cinc_train_attention_pooling_augmentation_smoothing.pt", map_location=device))

    # recreate the filters without removing noise
    def filter_train_pts(ecg_data):
        print(f"filtering {(feas1_noise_predictions[ecg_data.index] > 0).sum()} ECGs out")
        return ecg_data[(ecg_data["ptID"].isin(train_pts["ptID"]))]

    def filter_test_pts(ecg_data):
        return ecg_data[ecg_data["ptID"].isin(test_pts["ptID"])]

    def filter_val_pts(ecg_data):
        return ecg_data[ecg_data["ptID"].isin(test_pts["ptID"])]

    feas1_train_dataloader = DatasetSequenceIterator(loading_functions, [64 for n in range(num_chunks)], filter=filter_train_pts)

    class_counts = torch.tensor(ecg_data["class_index"].value_counts().sort_index().values.astype(np.float32))
    class_weights = (1/class_counts)
    class_weights /= torch.sum(class_weights)

    loss_func = focal_loss(class_weights, gamma=2, label_smoothing=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
    number_warmup_batches = 600
    scheduler = LambdaLR(optimizer, lr_lambda=warmup)

    model, losses = train(model, feas1_train_dataloader, feas1_test_dataloader)

    losses_np = np.array(losses)
    np.save("TrainedModels/Transformer_27_May_feas1_train_attention_pooling_augmentation_smoothing_training_curve_nk_beats_retrain", losses_np)

    torch.save(model.state_dict(), "TrainedModels/Transformer_27_May_feas1_train_attention_pooling_augmentation_smoothing_nk_beats_retrain.pt")
except Exception as e:
    print("Error occured in stage 2")
    print(e)



print("Stage 3: Training the initial model with noisy samples included")
try:
    import Models.SpectrogramTransformer
    importlib.reload(Models.SpectrogramTransformer)
    from Models.SpectrogramTransformer import TransformerModel

    model = TransformerModel(3, embed_dim, n_head, 512, 6, n_fft, n_inp_rri, device=device, enable_rri=False).to(device)
    model.load_state_dict(torch.load("TrainedModels/Transformer_12_May_cinc_trained_initial.pt", map_location=device))

    class_counts = torch.tensor(ecg_data["class_index"].value_counts().sort_index().values.astype(np.float32))
    class_weights = (1/class_counts)
    class_weights /= torch.sum(class_weights)

    loss_func = focal_loss(class_weights, gamma=2, label_smoothing=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
    number_warmup_batches = 600
    scheduler = LambdaLR(optimizer, lr_lambda=warmup)

    model, losses = train(model, feas1_train_dataloader, feas1_test_dataloader)

    losses_np = np.array(losses)
    np.save("TrainedModels/Transformer_27_May_feas1_train_initial_training_curve_nk_beats", losses_np)

    torch.save(model.state_dict(), "TrainedModels/Transformer_27_May_feas1_train_initial_nk_beats.pt")
except Exception as e:
    print("Error occured in stage 3")
    print(e)

### Model testing

In [289]:
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, multilabel_confusion_matrix

def get_predictions(model, dataloader, dataset):

    attentions = []

    """
    def hook(module, x, y):
        for a in y[1]:
            attentions.append(a.detach().cpu().numpy())

    attention_hook = model.attention_pooling.attn.register_forward_hook(hook)
    """

    model.eval()

    true_labels = []
    predictions = []

    outputs = []
    inds = []

    with torch.no_grad():
        for i, (signals, labels, ind) in enumerate(dataloader):
            signal = signals[0].to(device).float()
            rris = signals[1].to(device).float()
            rri_len = signals[2].to(device).float()

            labels = labels.long().detach().numpy()
            true_labels.append(labels)

            output = model(signal, rris, rri_len).detach().to("cpu").numpy() # rris).detach().to("cpu").numpy()

            prediction = output # np.argmax(output, axis=-1)
            predictions.append(prediction)

            for i, o in zip(ind, output):
                outputs.append(o)
                if isinstance(i, str):
                    inds.append(i)
                else:
                    inds.append(i.item())

    dataset["prediction"] = pd.Series(data=outputs, index=inds)
    # dataset["attention"] = pd.Series(data=attentions, index=inds)

    predictions = np.concatenate(predictions)
    true_labels = np.concatenate(true_labels)

    # attention_hook.remove()

    return predictions, true_labels

predictions, true_labels = get_predictions(model, feas1_val_dataloader, feas1_ecg_data_val)
conf_mat = confusion_matrix(true_labels, np.argmax(predictions, axis=1))

In [290]:
feas1_ecg_data_test["noise_probs"] = feas1_ecg_data_test["noise_prediction"]
feas1_ecg_data_val["noise_probs"] = feas1_ecg_data_val["noise_prediction"]

def get_noise_free_conf_mat(dataset):
   return confusion_matrix(dataset[dataset["noise_probs"] < 0]["class_index"], dataset[dataset["noise_probs"] < 0]["prediction"].map(np.argmax))

noise_free_conf_mat = get_noise_free_conf_mat(feas1_ecg_data_val)

In [291]:
def F1_ind(conf_mat, ind):
    return (2 * conf_mat[ind, ind])/(np.sum(conf_mat[ind]) + np.sum(conf_mat[:, ind]))

def print_results(conf_mat):
    print("Confusion matrix:")
    print(conf_mat)

    print(f"Sensitivity: {conf_mat[1, 1]/np.sum(conf_mat[1]):0.3f}")
    print(f"Specificity: {(conf_mat[0, 0] + conf_mat[0, 2] + conf_mat[2, 0] + conf_mat[2, 2])/(np.sum(conf_mat[0]) + np.sum(conf_mat[2])):0.3f}")
    print("")

    print(f"Normal F1: {F1_ind(conf_mat, 0):0.3f}")
    print(f"AF F1: {F1_ind(conf_mat, 1):0.3f}")
    print(f"Other F1: {F1_ind(conf_mat, 2):0.3f}")

    print()

print_results(conf_mat)

Confusion matrix:
[[10169   298 12395]
 [   65    61     0]
 [  357    56    39]]
Sensitivity: 0.484
Specificity: 0.985

Normal F1: 0.608
AF F1: 0.226
Other F1: 0.006



In [292]:
# Print noise free conf mats
print_results(noise_free_conf_mat)

Confusion matrix:
[[ 7545   112 11494]
 [   56    34     0]
 [  305    48    38]]
Sensitivity: 0.378
Specificity: 0.992

Normal F1: 0.558
AF F1: 0.239
Other F1: 0.006



In [197]:
misclassified_inds = feas1_ecg_data_test[feas1_ecg_data_test["prediction"].map(np.argmax) != feas1_ecg_data_test["class_index"]].index.values

In [212]:
feas1_test_interesting = feas1_ecg_data_test[feas1_ecg_data_test["class_index"] != 0]
feas1_test_interesting_dataloader = get_dataloaders(feas1_test_interesting, 64)

In [204]:
from sklearn.metrics import precision_recall_curve
from scipy.special import softmax

fig, ax = plt.subplots(figsize=(6, 4), dpi=250)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

ax.set_xlabel("Recall")
ax.set_ylabel("Precision")

labels = ["Normal", "AF", "Other"]

for i in range(3):
    p, r, d = precision_recall_curve((feas1_ecg_data_val_clean["class_index"] == i),  feas1_ecg_data_val_clean["prediction"].map(lambda x: softmax(x)[i]))
    ax.plot(r, p, label=labels[i])
    # plt.xlim((0, 1.1))
    # plt.ylim((0, 1.1))

    # closest_point_to_0_final = np.argmin(np.abs(d))
    # ax.plot(r[closest_point_to_0_final], p[closest_point_to_0_final], "o", color="#ff7f0e", label=r"$p(AF) = 0.5$")

ax.legend()
plt.show()
# fig.savefig("FinalReportFigs/CNN_NoiseDetect_precision_recall.png")

In [108]:
importlib.reload(Utilities.Plotting)
from Utilities.Plotting import *

In [210]:
import scipy.signal

dataset = feas1_ecg_data_val_clean
dataset["class_prediction"] = dataset["prediction"].map(lambda x: np.argmax(x))
#  & (dataset["noise_probs"] < 0)

selection = dataset[(dataset["class_prediction"] == 2) & (dataset["class_index"] == 1)]


for ecg_ind, ecg in selection.sample(frac=1).iterrows():
    print(ecg_ind)
    print(ecg[["measDiag", "prediction", "class_index"]])
    # filtered_ecg = scipy.signal.sosfiltfilt(sos, ecg["data"], padlen=150)

    plot_ecg(ecg["data"], 300, n_split=3, r_peaks=ecg["r_peaks"])
    plot_ecg_spectrogram(ecg["data"], 300, n_split=3, cut_range=[2, 18], figsize=(6, 2.5), export_quality=True)
    plot_ecg_drr(ecg["rri_feature"], ecg["rri_len"], export_quality=True)

    plt.show()

74351
measDiag                                 DiagEnum.AF
prediction     [-0.28620133, -0.32692084, 0.7239914]
class_index                                        1
Name: 74351, dtype: object
74431
measDiag                                DiagEnum.AF
prediction     [-0.7460612, 0.22486098, 0.49314404]
class_index                                       1
Name: 74431, dtype: object
74435
measDiag                                 DiagEnum.AF
prediction     [-0.017127324, -0.4723771, 0.8629926]
class_index                                        1
Name: 74435, dtype: object
74385
measDiag                                 DiagEnum.AF
prediction     [-0.76330775, 0.16884518, 0.48620653]
class_index                                        1
Name: 74385, dtype: object


KeyboardInterrupt: 

In [112]:
conf_mat_initial_transformer = np.array([[[ 717,   41,  226],
 [  88,  686,  129],
 [ 456,  184, 1407]], # CinC2020
[[   0, 2237, 1457],
 [   0,  437,   67],
 [   0, 1094,  555]],  # CinC2017
[[7384, 2281, 9848],
 [   0,    8,    8],
 [ 249,  151,  357]],  # Safer Feas2
[[ 8502,  2400, 11362],
 [    2,    50,    67],
 [  223,    75,   256]]])  # safer feas1

conf_mat_fine_tuned = np.array([[[  0, 320, 419],
 [  0,  91,  10],
 [  0, 171, 159]],  # CinC 2017
[[14918,  1543,  3052],
 [    3,     7,     6],
 [  562,    79,   116]],  # SAFER feas 2
[[17558,  2061,  2829],
 [    7,    92,    34],
 [  343,    60,   159]]])  # SAFER feas 1


conf_mat_final = np.array([[[20413,   376,  1475],
 [   13,    82,    38],
 [  203,    73,   282]],  # SAFER feas1
[[19113,    19,   381],
 [    7,     5,     4],
 [  428,    67,   262]]]) # SAFER feas2


for c in conf_mat_final:
    plot_confusion_matrix_2(c, ["Normal", "AF", "Other Rhythm"], colour="Blues")

In [228]:
# Plot with attention maps

import scipy.signal

dataset = feas1_test_interesting
dataset["class_prediction"] = dataset["prediction"].map(lambda x: np.argmax(x))
#  & (dataset["noise_probs"] < 0)

selection = dataset[(dataset["class_prediction"] == 1) & (dataset["class_index"] == 2)]



for ecg_ind, ecg in selection.dropna(subset=["attention"]).sample(frac=1).iterrows():
    print(ecg_ind)
    print(ecg[["measDiag", "prediction", "class_index"]])
    # filtered_ecg = scipy.signal.sosfiltfilt(sos, ecg["data"], padlen=150)
    plot_ecg_with_attention(ecg["data"][:3000], 300, n_split=1, attention=ecg["attention"][0][:, :96], figsize=(6, 4), export_quality=True)
    # plot_ecg_spectrogram(ecg["data"], 300, n_split=3, cut_range=[2, 18])
    # plot_ecg_poincare(ecg["rri_feature"], ecg["rri_len"])
    plt.show()



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



101033
measDiag                            DiagEnum.NoAF
prediction     [-2.428257, 0.87048155, 0.7332832]
class_index                                     2
Name: 101033, dtype: object
Plotting attention
54862
measDiag                             DiagEnum.NoAF
prediction     [-2.3245063, 0.9324973, 0.23180105]
class_index                                      2
Name: 54862, dtype: object
Plotting attention
15492
measDiag                             DiagEnum.NoAF
prediction     [-2.7731287, 1.894449, -0.38886702]
class_index                                      2
Name: 15492, dtype: object
Plotting attention
54477
measDiag                              DiagEnum.NoAF
prediction     [-3.2570727, 2.3353019, -0.08587719]
class_index                                       2
Name: 54477, dtype: object
Plotting attention
1445
measDiag                               DiagEnum.NoAF
prediction     [-0.45606652, 0.4858007, -0.10426653]
class_index                                        2
Name: 1445, dt

KeyboardInterrupt: 

In [77]:
# Compute patient wise performance

dataset = feas1_ecg_data_test_clean

dataset["class_prediction"] = dataset["prediction"].map(np.argmax)
dataset["pred_af"] = dataset["class_prediction"] == 1
pt_diagnoses = dataset.groupby("ptID")["pred_af"].any()
print(pt_diagnoses.value_counts())

False    225
True      94
Name: pred_af, dtype: int64


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dataset["class_prediction"] = dataset["prediction"].map(np.argmax)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dataset["pred_af"] = dataset["class_prediction"] == 1


In [68]:
pt_data["ptDiag"].value_counts()

4    2027
2      65
3      38
5       9
1       2
Name: ptDiag, dtype: int64

In [78]:
feas1_pt_data = pt_data[pt_data["ptDiag"].isin([4, 2, 1])]

In [79]:
feas1_pt_data.loc[:, "pt_prediction_af"] = pt_diagnoses

val_patients = feas1_pt_data.dropna(subset=["pt_prediction_af"])
pt_conf_mat = confusion_matrix((val_patients["ptDiag"] == 2).values, val_patients["pt_prediction_af"].astype(bool))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  feas1_pt_data.loc[:, "pt_prediction_af"] = pt_diagnoses


In [80]:
def print_binary_results(conf_mat):
    print("Confusion matrix:")
    print(conf_mat)

    print(f"Sensitivity: {conf_mat[1, 1]/np.sum(conf_mat[1]):0.3f}")
    print(f"Specificity: {(conf_mat[0, 0])/np.sum(conf_mat[0]):0.3f}")
    print("")

    print(f"Normal F1: {F1_ind(conf_mat, 0):0.3f}")
    print(f"AF F1: {F1_ind(conf_mat, 1):0.3f}")

print_binary_results(pt_conf_mat)

Confusion matrix:
[[219  80]
 [  2  10]]
Sensitivity: 0.833
Specificity: 0.732

Normal F1: 0.842
AF F1: 0.196


In [170]:
# Check the worst case how many ECGs would be reviewed

val_patients[val_patients["pt_prediction_af"] == 1]["noHQrecs"].sum()

1330

## Inspect the attention mechanism (not very useful yet)

In [213]:
# model.transformer_encoder.layers.
from plotly.subplots import make_subplots
fig = make_subplots(rows=2, cols=1)

def patch_attention(m):
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs['need_weights'] = True
        kwargs['average_attn_weights'] = False

        return forward_orig(*args, **kwargs)

    m.forward = wrap

attentions = []
inds = []

def save_outputs(module, x, y):
    for att in y[1]:
        attentions.append(att.cpu().numpy())

patch_attention(model.transformer_encoder.layers[0].self_attn)
attention_hook = model.transformer_encoder.layers[0].self_attn.register_forward_hook(save_outputs)

model.eval()
with torch.no_grad():
    for i, (signals, labels, ind) in enumerate(feas1_test_interesting_dataloader):
        signal = signals[0].to(device).float()
        rris = signals[1].to(device).float()
        rri_len = signals[2].to(device).float()

        labels = labels.long().detach().numpy()
        for i in ind:
            if isinstance(i, str):
                inds.append(i)
            else:
                inds.append(i.item())

        output = model(signal, rris, rri_len) # rris).detach().to("cpu").numpy()
        # plot_ecg(signal[0].cpu().numpy())
        # plt.show()

attention_hook.remove()
# attentions = []

In [214]:
attentions = attentions[::2]

In [215]:
feas1_test_interesting["attention"] = pd.Series(data=attentions, index=inds)



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



In [191]:
attention_hook.remove()

### Inspect the attention pooling weights

In [170]:
# model.transformer_encoder.layers.
from plotly.subplots import make_subplots
fig = make_subplots(rows=2, cols=1)

attentions = None

def hook(module, x, y):
    global attentions
    print("hook")
    attentions = y[1].detach().to("cpu").numpy()

attention_hook = model.attention_pooling.attn.register_forward_hook(hook)

with torch.no_grad():
    for i, (signals, labels, ind) in enumerate(test_dataloader_safer):
        print(signals.shape)
        signals = torch.transpose(signals.to(device), 0, 1).float()
        # fft = torch.abs(torch.fft.fft(signals))
        # signals = torch.cat([signals, fft], dim=1)
        labels = labels.long().detach().numpy()
        output = model(signals).detach().to("cpu").numpy()

        if labels[0] == 0:
            print(attentions.shape)
            fig = make_subplots(2, 1)
            fig.add_trace(go.Scatter(y=signals[:, 0].detach().to("cpu").numpy()), row=1, col=1)
            for j in range(attentions.shape[-2]):
                fig.add_trace(go.Scatter(y=attentions[0, j, :]), row=2, col=1)
            fig.show()

        if i == 10:
            break

attention_hook.remove()

torch.Size([128, 9120])
hook


AttributeError: 'NoneType' object has no attribute 'append'

### Inspect the final classification layers

In [61]:
fc1_weight = model.decoder1.weight.data
fc2_weight = model.decoder2.weight.data

plt.imshow(fc1_weight.cpu().numpy())
plt.show()

In [67]:
dataset = test_dataset_safer
dataset["class_prediction"] = dataset["prediction"].map(lambda x: np.argmax(x))
selection = dataset[(dataset["class_prediction"] == 1) & (dataset["class_index"] == 1) & (dataset["noise_probs"]< 0)]

In [69]:
np_signal = np.vstack(selection["data"].values)
np_rri = np.vstack(selection["rri_feature"].values)
rri_len = selection["rri_len"].values

In [70]:
signal = torch.tensor(np_signal, dtype=torch.float, device=device)
rri = torch.tensor(np_rri, dtype=torch.float, device=device)
rri_lens = torch.tensor(rri_len, device=device)

encoder_out = None

def get_encoding(module, x, y):
    global encoder_out
    print("hook")
    print(x[0].shape)
    encoder_out = x[0].detach().to("cpu")

encoding_hook = model.decoder1.register_forward_hook(get_encoding)

output = model(signal, rri, rri_lens)

encoding_hook.remove()

# Now recreate the output from just the RRI or ECG and see which makes the biggest impact
fc1_weight = model.decoder1.weight.data.to("cpu")
fc2_weight = model.decoder2.weight.data.to("cpu")


ecg_out = fc1_weight[:, :128] @ torch.unsqueeze(encoder_out[:, :128], dim=-1)
rri_out = fc1_weight[:, 128:] @ torch.unsqueeze(encoder_out[:, 128:], dim=-1)

print(ecg_out.shape)
print(ecg_out.shape)

ecg_out = nn.functional.relu(ecg_out)
rri_out = nn.functional.relu(rri_out)

ecg_out = fc2_weight @ ecg_out
rri_out = fc2_weight @ rri_out

print(ecg_out.shape)
print(ecg_out.shape)

plt.figure()
plt.title("ECG Signal Outputs")
plt.imshow(ecg_out)
plt.colorbar()

plt.figure()
plt.title("RRI Sequence Outputs")
plt.imshow(rri_out)
plt.colorbar()

plt.show()

hook
torch.Size([17, 192])
torch.Size([17, 128, 1])
torch.Size([17, 128, 1])
torch.Size([17, 3, 1])
torch.Size([17, 3, 1])


### TSNE the encoder outputs

In [71]:
from sklearn.manifold import TSNE

In [128]:
encoder_out = []
class_indexes = []
inds = []

def get_encoding(module, x, y):
    encoder_out.append(x[0].detach().to("cpu").numpy())

encoding_hook = model.decoder1.register_forward_hook(get_encoding)

dataloader = test_dataloader_safer

with torch.no_grad():
        for i, (signals, labels, ind) in enumerate(dataloader):
            signal = signals[0].to(device).float()
            rris = signals[1].to(device).float()
            rri_len = signals[2].to(device).float()

            labels = labels.long().detach().numpy()
            class_indexes.append(labels)

            output = model(signal, rris, rri_len).detach().to("cpu").numpy()
            inds.append(ind.cpu().detach().numpy())

encoding_hook.remove()

encoder_out = np.concatenate(encoder_out, axis=0)
class_indexes = np.concatenate(class_indexes, axis=0)
inds = np.concatenate(inds, axis=0)

test_dataset_safer["encodings"] = pd.Series(data=[encoder_out[i] for i in range(encoder_out.shape[0])], index=inds)

tsne = TSNE(perplexity=30)
embeddings = tsne.fit_transform(encoder_out)

print(embeddings.shape)



(4078, 2)


In [132]:
tsne = TSNE(perplexity=10)
embeddings = tsne.fit_transform(encoder_out)

plt.scatter(embeddings[:, 0], embeddings[:, 1], c=class_indexes, marker="x")
plt.colorbar()

