In [1]:
import sys
import scipy.io
import scipy.signal as sig
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
from DataHandlers.DiagEnum import DiagEnum
import DataHandlers.SAFERDataset
import math
from ecgdetectors import Detectors

### Load SAFER data

In [40]:
feas2_pt_data, feas2_ecg_data = SAFERDataset.load_feas_dataset(2, "dataframe")

In [42]:
feas2_ecg_data = feas2_ecg_data[feas2_ecg_data["measDiag"].isin([DiagEnum.AF, DiagEnum.NoAF, DiagEnum.CannotExcludePathology])]

### Load CinC 2020 data

In [32]:
import CinC2020Dataset
import importlib
importlib.reload(CinC2020Dataset)

df = CinC2020Dataset.load_dataset(save_name="dataframe")

In [33]:
# At the moment we only select data with length which can be truncated to 3000 samples (10s)

def select_length(df):
    df_within_range = df[(df["length"] <= 5000) & (df["length"] >= 3000)].copy()
    df_within_range["data"] = df_within_range["data"].map(lambda x: x[:3000])
    df_within_range["length"] = df_within_range["data"].map(lambda x: x.shape[0])
    return df_within_range

df = select_length(df)
df["length"].value_counts()

3000    48030
Name: length, dtype: int64

In [34]:
print(df["heartrate"].min())
print(df["heartrate"].max())

plt.hist(df["heartrate"])
plt.show()

18.0
186.0


In [12]:
df = df[(df["heartrate"] > 50) & (df["heartrate"] < 100)]

In [5]:
df = df[df["class_index"].isin([0, 1])]

In [9]:
df = df[(df["heartrate"] < 50)]["diag_num"].value_counts()

In [21]:
plt.hist(df[df["diag_num"].map(lambda x: 427084000 in x)]["heartrate"])
plt.show()

In [31]:
df[df["heartrate"] < 50]["diag_num"].map(lambda x: 426177001 in x).value_counts()

False    546
True     367
Name: diag_num, dtype: int64

### Define the model

In [3]:
import torch
from torch import nn
from torch import functional as F

### Scratch space

In [28]:
max_len = 181
d_model = 32

pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(500.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)

In [60]:
pe.shape
fig = go.Figure()
fig.add_trace(go.Heatmap(z=pe[:, 0, :].T))
fig.show()

In [4]:
# Attention pooling

class SelfAttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf
    """
    def __init__(self, input_dim):
        super(SelfAttentionPooling, self).__init__()
        self.W = nn.Linear(input_dim, 1)

    def forward(self, batch_rep):
        """
        input:
            batch_rep : size (N, T, H), N: batch size, T: sequence length, H: Hidden dimension

        attention_weight:
            att_w : size (N, T, 1)

        return:
            utter_rep: size (N, H)
        """
        softmax = nn.functional.softmax
        att_w = softmax(self.W(batch_rep).squeeze(-1)).unsqueeze(-1)
        utter_rep = torch.sum(batch_rep * att_w, dim=1)

        return utter_rep

In [5]:
from typing import Callable
from torch import Tensor

class AttentionPool2d(nn.Module):
    "Attention for Learned Aggregation"
    def __init__(self,
        ni:int,
        bias:bool=True,
        norm:Callable[[int], nn.Module]=nn.LayerNorm
    ):
        super().__init__()
        self.norm = norm(ni)
        self.q = nn.Linear(ni, ni, bias=bias)
        self.vk = nn.Linear(ni, ni*2, bias=bias)
        self.proj = nn.Linear(ni, ni)

    def forward(self, x:Tensor, cls_q:Tensor):
        x = self.norm(x.flatten(2))
        B, N, C = x.shape

        q = self.q(cls_q.expand(B, N, -1))
        k, v = self.vk(x).reshape(B, N, 2, C).permute(2, 0, 1, 3).chunk(2, 0)
        k = k[0, :, :, :]
        v = v[0, :, :, :]

        attn = q @ k.transpose(-2, -1)
        attn = nn.functional.softmax(attn, dim=-1)
        # Once we have the attention take the mean over the time domain
        x = torch.mean((attn @ v), dim=-2)
        return x


class LearnedAggregation(nn.Module):
    "Learned Aggregation from https://arxiv.org/abs/2112.13692"
    def __init__(self,
        ni:int,
        attn_bias:bool=True,
        ffn_expand:int|float=3,
        norm:Callable[[int], nn.Module]=nn.LayerNorm,
        act_cls:Callable[[None], nn.Module]=nn.GELU,
    ):
        super().__init__()
        self.gamma_1 = nn.Parameter(1e-4 * torch.ones(ni))
        self.gamma_2 = nn.Parameter(1e-4 * torch.ones(ni))
        self.cls_q = nn.Parameter(torch.zeros(ni))
        self.attn = AttentionPool2d(ni, attn_bias, norm)
        self.norm = norm(ni)
        self.ffn = nn.Sequential(
            nn.Linear(ni, int(ni*ffn_expand)),
            act_cls(),
            nn.Linear(int(ni*ffn_expand), ni)
        )
        nn.init.trunc_normal_(self.cls_q, std=0.02)
        self.apply(self._init_weights)

    def forward(self, x:Tensor):
        x = self.cls_q + self.gamma_1 * self.attn(x, self.cls_q)
        return x + self.gamma_2 * self.ffn(self.norm(x))

    @torch.no_grad()
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

class CNNFeatureExtraction(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv_section1 = nn.Sequential(
            nn.Conv1d(1, 32, 17, stride=2, padding=8),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.BatchNorm1d(32)
        )

        self.conv_section2 = nn.Sequential(
            nn.Conv1d(32, 64, 11, stride=2, padding=5),
            nn.ReLU(),
            nn.MaxPool1d(2, padding=1),
            nn.BatchNorm1d(64)
        )

        self.conv_section3 = nn.Sequential(
            nn.Conv1d(64, 64, 11, stride=2, padding=5),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.BatchNorm1d(64)
        )

        self.conv_section4 = nn.Sequential(
            nn.Conv1d(64, 128, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.BatchNorm1d(128)
        )

        self.conv_section5 = nn.Sequential(
            nn.Conv1d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2,  padding=1),
            nn.BatchNorm1d(128)
        )

        self.conv_section6 = nn.Sequential(
            nn.Conv1d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2,  padding=1),
            nn.BatchNorm1d(128)
        )

    def forward(self, x):
         # [1, 1000]
        x = self.conv_section1(x)

        # [32, 250]
        x = self.conv_section2(x)

        # [64, 63]
        x = self.conv_section3(x)

        # [64, 16]
        x = self.conv_section4(x)

        # [128, 8]
        x = self.conv_section5(x)

        # [128, 4]
        x = self.conv_section6(x)
        # [128, 2]

        return torch.flatten(x, -2)

In [6]:
# Borrowed and modified from https://github.com/pytorch/examples/blob/main/word_language_model/model.py

# Temporarily leave PositionalEncoding module here. Will be moved somewhere else.
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens in the sequence.
        The positional encodings have the same dimension as the embeddings, so that the two can be summed.
        Here, we use sine and cosine functions of different frequencies.
    .. math:
        \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx)
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
       # >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=200):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(500.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    """Container module with an encoder, a recurrent or transformer module, and a decoder."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, inplen, dropout=0.1):
        super(TransformerModel, self).__init__()
        try:
            from torch.nn import TransformerEncoder, TransformerEncoderLayer
        except BaseException as e:
            raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or '
                              'lower.') from e

        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.ninp = ninp
        # self.attention_pooling = SelfAttentionPooling(ninp)
        self.attention_pooling = LearnedAggregation(ninp)
        self. layer_norm = nn.LayerNorm(ninp)
        # self.spectrogram_bn = nn.BatchNorm2d(1)
        self.decoder1 = nn.Linear(ninp, 128)
        self.decoder2 = nn.Linear(128, ntoken)
        self.n_fft = n_fft

        self.cnn_window_size = 1000
        self.cnn_stride = 500
        self.cnn = CNNFeatureExtraction()

        self.sigmoid = nn.Sigmoid()

        self.init_weights()


    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask


    def init_weights(self):
        initrange = 0.1
        nn.init.zeros_(self.decoder1.bias)
        nn.init.zeros_(self.decoder2.bias)
        nn.init.uniform_(self.decoder1.weight, -initrange, initrange)
        nn.init.uniform_(self.decoder2.weight, -initrange, initrange)

    def stft_and_reshape(self, src):
        src = torch.transpose(src, 0, 1)
        ecg_stfts = torch.stft(src, n_fft=n_fft, return_complex=True)
        src = torch.abs(ecg_stfts)

        # Layer norm
        src = (src - torch.mean(src, dim=1)[:, None, :])/torch.std(src, dim=1)[:, None, :]

        # Batch norm
        # src = self.spectrogram_bn(torch.unsqueeze(src, dim=1))[:, 0, :, :]

        src = torch.permute(src, [2, 0, 1])
        src = src[:, :, :self.ninp]

        return src

    def window(self, src):
        src = torch.transpose(src, 0, 1)
        src = src.unfold(1, self.cnn_window_size, self.cnn_stride)

        return src

    def forward(self, src):
        # src = self.stft_and_reshape(src)
        src = self.window(src)
        B, N, _ = src.shape
        # Unsqueeze for the CNN channel dimension, and flatten time and batch dimension
        src = torch.flatten(torch.unsqueeze(src, -2), 0, 1)
        src = self.cnn(src)

        src = torch.reshape(src, (N, B, -1))
        src = self.pos_encoder(src)

        output = self.transformer_encoder(src)
        output = torch.transpose(output, 0, 1)
        # output = torch.flatten(output, 1)
        output = self.attention_pooling(output)

        output = self.decoder1(output)
        output = nn.functional.relu(output)

        output = self.decoder2(output)
        # output = self.sigmoid(output)

        return output

In [None]:
model = TransformerModel(2, 256, 4, 1024, 4, 47, 512)
for param in model.parameters():

### Generate dataloaders

In [39]:
mapper = CinC2020Dataset.CinC2020DiagMapper()
num_unique_classes = len(mapper.diag_desc.index)
mapper.diag_desc.index

def class_index_map(diag):
    if diag == DiagEnum.NoAF:
        return 0
    elif diag == DiagEnum.AF:
        return 1
    elif diag == DiagEnum.CannotExcludePathology:
        return 2
    elif diag == DiagEnum.Undecided:
        return 0

def one_hot_map(diag_num):
    one_hot = np.zeros(num_unique_classes)
    one_hot[diag_num] = 1
    return one_hot

In [40]:
df

Unnamed: 0,data,fs,adc_gain,age,sex,diag_num,length,filepath,overall_diags,chal_diag_num,measDiag,class_index,r_peaks,heartrate
0,"[0.013503328064632175, 0.15966221418245521, 0....",300,1000.0,74,Sex.Male,[59118001],3000,C:\Users\daniel\Documents\CambridgeSoftwarePro...,,[17],DiagEnum.HeartBlock,2,"[97, 323, 503, 681, 858, 1036, 1213, 1395, 157...",100.000000
1,"[-0.009813904287011859, -0.003005242496288474,...",300,1000.0,49,Sex.Female,[426783006],3000,C:\Users\daniel\Documents\CambridgeSoftwarePro...,,[0],DiagEnum.NoAF,0,"[95, 191, 436, 689, 907, 1158, 1378, 1611, 184...",84.000000
2,"[0.05191035007645714, -0.06636734822370909, -0...",300,1000.0,81,Sex.Female,[164889003],3000,C:\Users\daniel\Documents\CambridgeSoftwarePro...,,[1],DiagEnum.AF,1,"[98, 248, 341, 462, 607, 737, 836, 969, 1151, ...",144.000000
3,"[-0.2510777303744558, 0.15572253059268046, 0.2...",300,1000.0,45,Sex.Male,[164889003],3000,C:\Users\daniel\Documents\CambridgeSoftwarePro...,,[1],DiagEnum.AF,1,"[99, 218, 478, 624, 770, 935, 1113, 1297, 1459...",115.513393
5,"[0.012547497293927757, 0.01236343167276485, -0...",300,1000.0,29,Sex.Male,[59118001],3000,C:\Users\daniel\Documents\CambridgeSoftwarePro...,,[17],DiagEnum.HeartBlock,2,"[101, 204, 489, 762, 1023, 1290, 1582, 1846, 2...",68.571429
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
56053,"[0.010536135052487638, -0.000324369066614873, ...",300,1000.0,66,Sex.Undefined,[427172004],3000,C:\Users\daniel\Documents\CambridgeSoftwarePro...,"[425856008, 426627000, 164884008, 164865005]",[12],DiagEnum.CannotExcludePathology,2,"[112, 303, 602, 849, 1033, 1353, 1601, 1790, 2...",72.000000
56054,"[0.02987876366143985, 0.019229092676427178, 0....",300,1000.0,66,Sex.Undefined,[427172004],3000,C:\Users\daniel\Documents\CambridgeSoftwarePro...,"[425856008, 426627000, 164884008, 164865005]",[12],DiagEnum.CannotExcludePathology,2,"[119, 363, 612, 866, 1049, 1379, 1635, 1895, 2...",72.000000
56055,"[-0.05881361546820613, -0.03620388268373439, -...",300,1000.0,66,Sex.Undefined,[427172004],3000,C:\Users\daniel\Documents\CambridgeSoftwarePro...,"[425856008, 426627000, 164884008, 164865005]",[12],DiagEnum.CannotExcludePathology,2,"[95, 412, 664, 865, 1148, 1392, 1575, 1886, 21...",72.000000
56056,"[0.013912807193116584, 0.014729595082520405, 0...",300,1000.0,66,Sex.Undefined,[427172004],3000,C:\Users\daniel\Documents\CambridgeSoftwarePro...,"[425856008, 426627000, 164884008, 164865005]",[12],DiagEnum.CannotExcludePathology,2,"[102, 196, 460, 732, 1003, 1239, 1542, 1809, 2...",72.000000


In [41]:
# Onehot encoding
from torch.utils.data import Dataset, DataLoader

class Dataset(torch.utils.data.Dataset):
  'Characterizes a dataset for PyTorch'
  def __init__(self, dataset):
        'Initialization'
        self.dataset = dataset

  def __len__(self):
        'Denotes the total number of samples'
        return len(self.dataset.index)

  def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        row = self.dataset.iloc[index]

        X = row["data"]
        y = row["class_index"]
        ind = row.name

        return X, y, ind

In [140]:
feas2_ecg_data["class_index"] = feas2_ecg_data["measDiag"].map(lambda x: class_index_map(x))

NameError: name 'feas2_ecg_data' is not defined

In [None]:
# For SAFER data
# Split train and test data according to each patient
def make_SAFER_dataloaders(pt_data, ecg_data, test_frac):
    pt_data["noLQrecs"] = pt_data["noRecs"] - pt_data["noHQrecs"]  # for Feas1 this might include stuff flagged by zenicor as noisy?
    train_patients = []
    test_patients = []

    for val, df in pt_data.groupby("noLQrecs"):
        print(f"processing {val}")
        print(f"number of patients {len(df.index)}")
        test = df.sample(frac=test_frac)
        test_patients.append(test)
        train_patients.append(df[~df["ptID"].isin(test["ptID"])])

    train_pt_df = pd.concat(train_patients)
    test_pt_df = pd.concat(test_patients)

    print(f"Test high quality: {test_pt_df['noHQrecs'].sum()} low quality: {test_pt_df['noLQrecs'].sum()} ")
    print(f"Train high quality: {train_pt_df['noHQrecs'].sum()} low quality: {train_pt_df['noLQrecs'].sum()} ")

    train_dataloader = None
    test_dataloader = None

    train_dataset = None
    test_dataset = None

    if not train_pt_df.empty:
        # get ECG datasets
        train_dataset = ecg_data[ecg_data["ptID"].isin(train_pt_df["ptID"])]
        # Normalise
        train_dataset["data"] = (train_dataset["data"] - train_dataset["data"].map(lambda x: x.mean()))/train_dataset["data"].map(lambda x: x.std())
        torch_dataset_train = Dataset(train_dataset)
        train_dataloader = DataLoader(torch_dataset_train, batch_size=32, shuffle=True, pin_memory=True)

    if not test_pt_df.empty:
        test_dataset = ecg_data[(ecg_data["ptID"].isin(test_pt_df["ptID"])) & (ecg_data["measDiag"] != DiagEnum.Undecided)]
        test_dataset["data"] = (test_dataset["data"] - test_dataset["data"].map(lambda x: x.mean()))/test_dataset["data"].map(lambda x: x.std())
        torch_dataset_test = Dataset(test_dataset)
        test_dataloader = DataLoader(torch_dataset_test, batch_size=32, shuffle=True, pin_memory=True)

    return train_dataloader, test_dataloader, train_dataset, test_dataset

train_dataloader, test_dataloader, train_dataset, test_dataset = make_SAFER_dataloaders(feas2_pt_data, feas2_ecg_data, test_frac=1)

# Remake the test dataset without any undecided values
test_dataset = test_dataset[test_dataset["measDiag"] != DiagEnum.Undecided]
torch_dataset_test = Dataset(test_dataset)
test_dataloader = DataLoader(torch_dataset_test, batch_size=32, shuffle=True, pin_memory=True)

In [42]:
### Make dataloaders for CinC data
from sklearn.model_selection import train_test_split

train_dataset, test_dataset = train_test_split(df, test_size=0.2, stratify=df["class_index"])
test_dataset = test_dataset[test_dataset["measDiag"] != DiagEnum.Undecided]  # Should just remove any errors in loading the dataset

torch_dataset_test = Dataset(test_dataset)
test_dataloader = DataLoader(torch_dataset_test, batch_size=128, shuffle=True, pin_memory=True)

torch_dataset_train = Dataset(train_dataset)
train_dataloader = DataLoader(torch_dataset_train, batch_size=128, shuffle=True, pin_memory=True)

In [43]:
test_dataset["class_index"].value_counts()

2    5256
0    3528
1     822
Name: class_index, dtype: int64

In [36]:
# Remove Cannot exclude pathology from the testing set (but possibly still in train)

test_dataset = test_dataset[(test_dataset["measDiag"] != DiagEnum.CannotExcludePathology) & (test_dataset["measDiag"] != DiagEnum.HeartBlock)]
torch_dataset_test = Dataset(test_dataset)
test_dataloader = DataLoader(torch_dataset_test, batch_size=128, shuffle=True, pin_memory=True)

In [154]:
train_dataset["measDiag"].value_counts()

DiagEnum.NoAF    12569
DiagEnum.AF       7915
Name: measDiag, dtype: int64

### Prepare for training

In [44]:
if torch.cuda.is_available():
    print("Using Cuda")
    device = torch.device("cuda")
else:
    print("Using CPU")
    device = torch.device("cpu")

Using Cuda


In [45]:
from torch.optim.lr_scheduler import StepLR, LambdaLR, SequentialLR

In [46]:
num_epochs = 5
n_head = 4
embed_dim = 384 # int(n_fft/2)

model = TransformerModel(2, embed_dim, n_head, 1024, 4, 47).to(device)

# Use weightings to handle class imbalance

class_counts = torch.tensor(train_dataset["class_index"].value_counts().sort_index().values.astype(np.float32))
class_weights = (1/class_counts)
class_weights /= torch.sum(class_weights)
print(class_weights)

"""
c = {i: 0 for i in mapper.diag_desc.index}
for _, diags in df["diag_num"].iteritems():
    for d in diags:
        c[d] += 1

class_counts = torch.tensor(list(c.values()))
class_weights = (1/class_counts)
class_weights /= torch.sum(class_weights)
"""

def multiclass_cross_entropy_loss(pred, targets):
    return - torch.sum(class_weights[None, :] * targets[None, :] * torch.log(pred))\
           - torch.sum(class_weights[None, :] * (1 - targets[None, :]) * torch.log(1-pred))

loss_func = nn.CrossEntropyLoss(weight=class_weights) # multiclass_cross_entropy_loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

number_warmup_epochs = 3
def warmup(current_step: int):
    return 1 / (10 ** (float(number_warmup_epochs - current_step)))
warmup_scheduler = LambdaLR(optimizer, lr_lambda=warmup)

scheduler = SequentialLR(optimizer, [warmup_scheduler, scheduler], [number_warmup_epochs])


NameError: name 'n_fft' is not defined

In [71]:
# Debugging training issues
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x2860c7ada50>

In [72]:
import copy
model = model.to(device)
num_epochs = 30

def train(model):
    best_test_loss = 100
    best_epoch = -1
    best_model = copy.deepcopy(model).cpu()

    losses = []

    for epoch in range(num_epochs):
        total_loss = 0
        print(f"starting epoch {epoch} ...")
        # Train
        num_batches = 0
        model.train()
        for i, (signals, labels, _) in enumerate(train_dataloader):
            signals = torch.transpose(signals.to(device), 0, 1).float()
            # fft = torch.abs(torch.fft.fft(signals))
            # signals = torch.cat([signals, fft], dim=1)

            if torch.any(torch.isnan(signals)):
                print("Signals are nan")
                continue
            labels = labels.long()
            optimizer.zero_grad()
            output = model(signals).to("cpu")
            loss = loss_func(output, labels)
            if torch.isnan(loss):
                raise ValueError
            loss.backward()
            optimizer.step()
            num_batches += 1
            total_loss += float(loss)

        print(f"Epoch {epoch} finished with average loss {total_loss/num_batches}")
        print("Testing ...")
        # Test
        num_test_batches = 0
        test_loss = 0
        with torch.no_grad():
            model.eval()
            for i, (signals, labels, _) in enumerate(test_dataloader):
                signals = torch.transpose(signals.to(device), 0, 1).float()
                # fft = torch.abs(torch.fft.fft(signals))
                # signals = torch.cat([signals, fft], dim=1)
                if torch.any(torch.isnan(signals)):
                    print("Signals are nan")
                    continue

                labels = labels.long()
                output = model(signals).to("cpu")
                loss = loss_func(output, labels)
                test_loss += float(loss)
                num_test_batches += 1

        print(f"Average test loss: {test_loss/num_test_batches}")
        losses.append([total_loss/num_batches, test_loss/num_test_batches])

        if test_loss/num_test_batches < best_test_loss:
            best_model = copy.deepcopy(model).cpu()
            best_test_loss = test_loss/num_test_batches
            best_epoch = epoch
        else:
            if best_epoch + 5 <= epoch:
                return best_model, losses

        scheduler.step()

    return best_model, losses

model, losses = train(model)
model = model.to(device)

starting epoch 0 ...
Epoch 0 finished with average loss 0.6931325487729882
Testing ...
Average test loss: 0.69313010374705
starting epoch 1 ...
Epoch 1 finished with average loss 0.693138711091851
Testing ...
Average test loss: 0.6931260287761688
starting epoch 2 ...
Epoch 2 finished with average loss 0.6932617115373371
Testing ...
Average test loss: 0.6930919865767161
starting epoch 3 ...




Epoch 3 finished with average loss 0.6947976856672463
Testing ...
Average test loss: 0.6931729336579641
starting epoch 4 ...
Epoch 4 finished with average loss 0.6932880322472388
Testing ...
Average test loss: 0.6930965463320414
starting epoch 5 ...
Epoch 5 finished with average loss 0.6931935889380318
Testing ...
Average test loss: 0.6930646896362305
starting epoch 6 ...
Epoch 6 finished with average loss 0.6931723877161491
Testing ...
Average test loss: 0.6930474638938904
starting epoch 7 ...
Epoch 7 finished with average loss 0.6931819439936084
Testing ...
Average test loss: 0.6931290686130523
starting epoch 8 ...
Epoch 8 finished with average loss 0.6931540765682188
Testing ...
Average test loss: 0.6930527746677398
starting epoch 9 ...
Epoch 9 finished with average loss 0.6931493467643481
Testing ...
Average test loss: 0.6931114415327708
starting epoch 10 ...
Epoch 10 finished with average loss 0.6931343925099412
Testing ...
Average test loss: 0.6930753449598949
starting epoch 11 .

In [155]:
# Save a model
torch.save(model.state_dict(), "TrainedModels/Transformer_spectrogram_features_learned_aggregation.pt")
train_dataset.to_pickle("TrainedModels/Transformer_spectrogram_features_learned_aggregation_train_set.pk")

In [32]:
# Load a model
model = TransformerModel(2, embed_dim, n_head, 1024, 4, 47, n_fft).to(device)
model.load_state_dict(torch.load("TrainedModels/Transformer_spectrogram_features_learned_aggregation.pt", map_location=device))

# Should load the test data as well

<All keys matched successfully>

### Model testing

In [50]:
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, multilabel_confusion_matrix

test_dataset["prediction"] = None

def get_predictions(model, dataloader, dataset):
    model.eval()

    true_labels = []
    predictions = []

    outputs = []
    inds = []

    with torch.no_grad():
        for i, (signals, labels, ind) in enumerate(dataloader):
            signals = torch.transpose(signals.to(device), 0, 1).float()
            # fft = torch.abs(torch.fft.fft(signals))
            # signals = torch.cat([signals, fft], dim=1)
            labels = labels.long().detach().numpy()
            true_labels.append(labels)

            output = model(signals).detach().to("cpu").numpy()

            prediction = output # np.argmax(output, axis=-1)
            predictions.append(prediction)

            for i, o in zip(ind, output):
                outputs.append(o)
                inds.append(int(i))

    dataset["prediction"] = pd.Series(data=outputs, index=inds)

    predictions = np.concatenate(predictions)
    true_labels = np.concatenate(true_labels)

    return predictions, true_labels

predictions, true_labels = get_predictions(model, test_dataloader, test_dataset)
conf_mat = confusion_matrix(true_labels, np.argmax(predictions, axis=1))

In [44]:
predictions, true_labels = get_predictions(model, test_dataloader, test_dataset)
conf_mat = multilabel_confusion_matrix(true_labels, predictions > 0.5)

  att_w = softmax(self.W(batch_rep).squeeze(-1)).unsqueeze(-1)


In [45]:
for i, c_mat in enumerate(conf_mat):
    print(mapper.mapToDesc(i))
    print(c_mat)

sinus rhythm
[2935   69  524]
atrial fibrillation
[ 46 686  90]
atrial flutter
[0 0 0]


In [46]:
test_dataset["class_index"].value_counts()

0    3528
1     822
Name: class_index, dtype: int64

In [51]:
# Same as the below function (as described in CinC)
def F1_ind(conf_mat, ind):
    return (2 * conf_mat[ind, ind])/(np.sum(conf_mat[ind]) + np.sum(conf_mat[:, ind]))

print("Confusion matrix:")
print(conf_mat)

print(f"Sensitivity: {conf_mat[1, 1]/np.sum(conf_mat[1])}")
print(f"Specificity: {conf_mat[0, 0]/np.sum(conf_mat[0])}")

print(f"Normal F1: {F1_ind(conf_mat, 0)}")
print(f"AF F1: {F1_ind(conf_mat, 1)}")
print(f"Other F1: {F1_ind(conf_mat, 2)}")
# print(f"Other arrythmia F1: {F1_ind(conf_mat, 1)}")

Confusion matrix:
[[2959   82  487]
 [  25  729   68]
 [1916  339 3001]]
Sensitivity: 0.8868613138686131
Specificity: 0.838718820861678
Normal F1: 0.7021831988609397
AF F1: 0.7393509127789046
Other F1: 0.6811166591012257


In [53]:
test_dataset["class_prediction"] = test_dataset["prediction"].map(lambda x: np.argmax(x))

In [54]:
selection = test_dataset[(test_dataset["class_prediction"] == 0) & (test_dataset["class_index"] == 2)]

from matplotlib.ticker import AutoMinorLocator

def plot_ecg(x, fs=300, n_split=1):
    sample_len = x.shape[0]
    time_axis = np.arange(sample_len)/fs

    cuts = np.round(np.linspace(0, sample_len-1, n_split+1)).astype(int)

    fig, ax = plt.subplots(n_split, 1, figsize=(12, 5), squeeze=False)
    for j in range(n_split):
        ax[j][0].plot(time_axis[cuts[j]:cuts[j+1]], x[cuts[j]:cuts[j+1]])
        ax[j][0].set_xlabel("Time")
        ax[j][0].set_xlim((time_axis[cuts[j]], time_axis[cuts[j+1]]))

        t_s = time_axis[cuts[j]]
        t_f = time_axis[cuts[j+1]]
        time_ticks = np.arange(t_s - t_s%0.2, t_f + (0.2 - t_f%0.2), 0.2)
        decimal_labels = ~np.isclose(time_ticks, np.round(time_ticks))
        time_labels = np.round(time_ticks).astype(int).astype(str)
        time_labels[decimal_labels] = ""

        ax[j][0].set_xticks(time_ticks, time_labels)


        ax[j][0].xaxis.set_minor_locator(AutoMinorLocator(5))
        ax[j][0].yaxis.set_minor_locator(AutoMinorLocator(5))

        ax[j][0].grid(which='major', linestyle='-', linewidth='0.5', color='black')
        ax[j][0].grid(which='minor', linestyle='-', linewidth='0.5', color='lightgray')

    plt.show()

c = DiagEnum.CannotExcludePathology

for _, ecg in selection.sample(frac=1).iterrows():
    print(ecg[["measDiag", "prediction", "diag_num"]])
    plot_ecg(ecg["data"], 300)
    plt.show()

measDiag                      DiagEnum.HeartBlock
prediction    [1.0444555, -1.3554548, 0.93679714]
diag_num                               [59118001]
Name: 6512, dtype: object
measDiag        DiagEnum.CannotExcludePathology
prediction    [2.1256232, -2.3721557, 1.127521]
diag_num                            [251146004]
Name: 11848, dtype: object
measDiag         DiagEnum.CannotExcludePathology
prediction    [2.6275034, -2.9053438, 1.3200564]
diag_num                   [39732003, 426783006]
Name: 28617, dtype: object
measDiag        DiagEnum.CannotExcludePathology
prediction    [1.8434985, -2.044042, 0.9577774]
diag_num                 [426783006, 427393009]
Name: 32937, dtype: object
measDiag         DiagEnum.CannotExcludePathology
prediction    [2.3725977, -2.8076935, 1.5380838]
diag_num                             [270492004]
Name: 3177, dtype: object
measDiag                     DiagEnum.HeartBlock
prediction    [2.6998997, -2.7097263, 0.8894341]
diag_num                             

KeyboardInterrupt: 

### Inspect the attention mechanism

In [172]:
# model.transformer_encoder.layers.
from plotly.subplots import make_subplots
fig = make_subplots(rows=2, cols=1)

def hook(module, x, y):
    fig = make_subplots(rows=2, cols=1)
    fig.add_trace(go.Heatmap(z=x[0][:, 0, :].cpu().numpy()), row=1, col=1)
    fig.add_trace(go.Heatmap(z=y[0][:, 0, :].cpu().numpy()), row=2, col=1)
    fig.show()

attention_hook = model.transformer_encoder.layers[0].self_attn.register_forward_hook(hook)

with torch.no_grad():
    for i, (signals, labels, ind) in enumerate(test_dataloader):
        print(signals.shape)
        fig = go.Figure()
        fig.add_trace(go.Scatter(y=signals[0]))
        fig.show()
        signals = torch.transpose(signals.to(device), 0, 1).float()
        # fft = torch.abs(torch.fft.fft(signals))
        # signals = torch.cat([signals, fft], dim=1)
        labels = labels.long().detach().numpy()

        output = model(signals).detach().to("cpu").numpy()
        break

attention_hook.remove()

torch.Size([32, 3000])


In [167]:
attention_hook.remove()