In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.preprocessing import minmax_scale, scale
import os
from SubModeles.EcgDataset import ECGDataset
import CustomStatisticCalculation as np_stat
from SubModeles.AutoencoderModels import ECGAutoencoder
from tabulate import tabulate
from array_gzip_io_utils import load_from_gz_file

In [None]:
# Experiment configuration. All three values are interpolated into the data and
# model directory names; SIGNAL_LENGTH is additionally passed to the statistics
# helper in the evaluation cell.
SIGNAL_LENGTH = 8000
TEST_SET_SIZE = 0.2
EXPERIMENT = 1

# Data and model trees share the same "<length>/Test-set-<size>/<experiment>" layout.
_RUN_SUBDIR = f"{SIGNAL_LENGTH}/Test-set-{TEST_SET_SIZE}/{EXPERIMENT}"
ROOT_DATA_PATH = f"TempData/Data/{_RUN_SUBDIR}"
TRAINED_MODEL_PATH = f"TempData/Models/{_RUN_SUBDIR}"

# Checkpoints are saved into this directory later on, so create it up front.
os.makedirs(TRAINED_MODEL_PATH, exist_ok=True)

In [None]:
# Use the first CUDA device when one is visible to torch, otherwise stay on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def prediction_post_processing(predictions, window_size=70, threshold=0.10):
    """Convert per-sample scores into a sparse 0/1 peak mask.

    For every row, scores that are not strictly above ``threshold`` are zeroed.
    The row is then scanned left to right: the first remaining positive sample
    opens a window of ``window_size`` samples, the largest score inside that
    window (first one on ties) is marked with 1, and scanning resumes right
    after the window. Windows that run past the end of the row are truncated.

    Args:
        predictions: Iterable of 1-D score sequences (e.g. a 2-D array with one
            row per signal).
        window_size: Number of samples covered by one detection window; at most
            one peak is emitted per window. Must be >= 1. Defaults to 70.
        threshold: Scores less than or equal to this value are ignored.
            Defaults to 0.10.

    Returns:
        ``np.ndarray`` with one ``float32`` row per input row, holding 1 at the
        selected peak positions and 0 elsewhere.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")

    all_post_processed = []
    for prediction in predictions:
        scores = np.asarray(prediction)
        # Zero out everything that does not clear the threshold so the argmax
        # below only ever competes among accepted samples.
        scores = np.where(scores > threshold, scores, 0)

        post_processed = np.zeros(len(scores), dtype=np.float32)
        next_free = 0  # first index not covered by the previous window
        # Only positive samples can open a window, so visit just those.
        for start in np.flatnonzero(scores > 0):
            if start < next_free:
                continue
            window = scores[start:start + window_size]
            post_processed[start + np.argmax(window)] = 1
            next_free = start + window_size

        all_post_processed.append(post_processed)
    return np.array(all_post_processed)

In [None]:
# Dataset loading
#
# The training pool is the concatenation of the MIT arrhythmia and QT training
# files; the China Signal Challenge 2020 file is currently commented out.

train_data = load_from_gz_file(f'{ROOT_DATA_PATH}/train-mit-arrhythmia-fs-400-prefered-leads.pkl.gz')
#train_data = np.concatenate((train_data, load_from_gz_file(f'{ROOT_DATA_PATH}/train-china-signal-chalenge-2020-fs-400-prefered-leads.pkl.gz')))
train_data = np.concatenate((train_data, load_from_gz_file(f'{ROOT_DATA_PATH}/train-qt-fs-400-prefered-leads.pkl.gz')))

# train_test_split shuffles on its own, seeded by random_state. No separate
# unseeded np.random.shuffle is done beforehand: it would change the row order
# on every run and so defeat the fixed seed, making the train/validation split
# non-reproducible.
train_data, validation_data = train_test_split(train_data, test_size=0.1, random_state=42)
print(f'Train data set: {len(train_data)}')

def __collate_fn(data):
    """Collate a list of dataset samples into batched tensors.

    Samples are addressed positionally. Fields 0, 1 and 2 (named signal,
    target and knowledge here) are each stacked across the batch, given a
    singleton axis at position 1 and wrapped as torch tensors. Fields 3 and 4
    are passed through unchanged as per-sample object arrays.

    Args:
        data: Sequence of samples, each exposing at least five positional fields.

    Returns:
        Tuple ``(signal_tensor, target_tensor, knowledge_tensor, field_3, field_4)``.
    """
    batch = np.array(data, dtype=object)

    def as_channel_tensor(column):
        # Stack one field over the batch, then insert the singleton axis at dim 1.
        stacked = np.stack(batch[:, column])
        return torch.from_numpy(np.expand_dims(stacked, 1))

    signal_tensor, target_tensor, knowledge_tensor = (
        as_channel_tensor(column) for column in range(3)
    )

    return signal_tensor, target_tensor, knowledge_tensor, batch[:, 3], batch[:, 4]

# Training batches are reshuffled every epoch; the trailing partial batch is
# dropped so every optimisation step sees a full batch of 16.
train_dataloader = DataLoader(ECGDataset(train_data), batch_size=16, shuffle=True, drop_last=True, collate_fn=__collate_fn)
# Validation has to score the same samples every epoch for the losses to be
# comparable, so it is neither shuffled nor truncated. Keeping the last partial
# batch also avoids an empty loader (and a division by zero when averaging the
# loss) if the validation split holds fewer than 16 samples.
validation_dataloader = DataLoader(ECGDataset(validation_data), batch_size=16, shuffle=False, drop_last=False, collate_fn=__collate_fn)

In [None]:
# Train CNN

# NOTE(review): the argument 2 presumably is the number of input channels
# (signal + knowledge are concatenated along dim 1 below) - confirm against
# ECGAutoencoder.
model = ECGAutoencoder(2)
model.to(device)

# BCELoss expects model outputs that are already probabilities in [0, 1].
criterion = nn.BCELoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 20
print(f'Epoch \t Training Loss \t Validation Loss')

for epoch in range(num_epochs):
    # Training pass
    model.train()
    train_loss = 0.0

    for batch in train_dataloader:
        signal_tensor, target_tensor, knowledge_tensor, origin_r_peak, annotation = batch

        # Signal and knowledge channels are stacked along dim 1 to form the model input.
        input_marged = torch.cat((signal_tensor, knowledge_tensor), dim=1).float()
        input_marged = input_marged.to(device)
        # BCELoss needs the target in the same floating dtype as the output,
        # so cast it the same way the input is cast.
        target_tensor = target_tensor.float().to(device)

        optimizer.zero_grad()

        outputs = model(input_marged)

        loss = criterion(outputs, target_tensor)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in validation_dataloader:
            signal_tensor, target_tensor, knowledge_tensor, origin_r_peak, annotation = batch

            input_marged = torch.cat((signal_tensor, knowledge_tensor), dim=1).float()
            input_marged = input_marged.to(device)
            target_tensor = target_tensor.float().to(device)

            outputs = model(input_marged)

            loss = criterion(outputs, target_tensor)
            val_loss += loss.item()

    # Report the mean per-batch loss for this epoch.
    epoch_train_loss = train_loss / len(train_dataloader)
    epoch_val_loss = val_loss / len(validation_dataloader)

    print(f'{epoch+1}/{num_epochs}\t{epoch_train_loss:.6f}\t{epoch_val_loss:.6f}')

# NOTE(review): the file name says "45epoch" and "china" although this cell runs
# num_epochs = 20 and the China set is commented out in the loading cell. The
# name is left unchanged because the continue-training cell loads this exact file.
torch.save(model.state_dict(), os.path.join(TRAINED_MODEL_PATH, "ECGAutoencoder-2-mit-china-qt-prefered-leads-fs-400-45epoch.pt"))

In [None]:
# Continue training: reload the checkpoint written by the training cell and
# fine-tune it with a 10x smaller learning rate.

model = ECGAutoencoder(2)
model.to(device)
# map_location keeps the load working when the checkpoint was saved on a CUDA
# machine but this session only has a CPU (or the other way round).
model.load_state_dict(torch.load(os.path.join(TRAINED_MODEL_PATH, "ECGAutoencoder-2-mit-china-qt-prefered-leads-fs-400-45epoch.pt"), map_location=device))

criterion = nn.BCELoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 6
for epoch in range(num_epochs):
    # Training pass
    model.train()
    train_loss = 0.0

    for batch in train_dataloader:
        signal_tensor, target_tensor, knowledge_tensor, origin_r_peak, annotation = batch

        # Signal and knowledge channels are stacked along dim 1 to form the model input.
        input_marged = torch.cat((signal_tensor, knowledge_tensor), dim=1).float()
        input_marged = input_marged.to(device)
        # BCELoss needs the target in the same floating dtype as the output.
        target_tensor = target_tensor.float().to(device)

        optimizer.zero_grad()
        outputs = model(input_marged)

        loss = criterion(outputs, target_tensor)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in validation_dataloader:
            signal_tensor, target_tensor, knowledge_tensor, origin_r_peak, annotation = batch

            input_marged = torch.cat((signal_tensor, knowledge_tensor), dim=1).float()
            input_marged = input_marged.to(device)
            target_tensor = target_tensor.float().to(device)

            outputs = model(input_marged)

            loss = criterion(outputs, target_tensor)
            val_loss += loss.item()

    # Report the mean per-batch loss for this epoch.
    epoch_train_loss = train_loss / len(train_dataloader)
    epoch_val_loss = val_loss / len(validation_dataloader)

    print(f'{epoch+1}/{num_epochs}\t{epoch_train_loss:.6f}\t{epoch_val_loss:.6f}')

# This file name is the one the evaluation cell loads.
torch.save(model.state_dict(), os.path.join(TRAINED_MODEL_PATH, "ECGAutoencoder-2-mit-china-qt-prefered-leads-fs-400-60epoch-CT1.pt"))

In [None]:
# Testing dataset loading
# Exactly one evaluation file is active at a time. To evaluate on another set,
# comment out the active assignment and uncomment one of the alternatives below.
test_data = load_from_gz_file(f'{ROOT_DATA_PATH}/test-mit-arrhythmia-fs-400-prefered-leads.pkl.gz')
# test_data = load_from_gz_file(f'{ROOT_DATA_PATH}/test-qt-fs-400-prefered-leads.pkl.gz')
# test_data = load_from_gz_file(f'{ROOT_DATA_PATH}/test-china-signal-chalenge-2020-fs-400-prefered-leads.pkl.gz')
# test_data = load_from_gz_file(f'{ROOT_DATA_PATH}/university-of-glasgow-2020-fs-400-prefered-leads.pkl.gz')

def __collate_fn(data):
    """Build one evaluation batch from a list of dataset samples.

    This re-declares the collate function of the training-data cell with the
    same behaviour. The first three positional fields of every sample (signal,
    target, knowledge) are stacked over the batch, expanded with a singleton
    axis at position 1 and converted to torch tensors; fields 3 and 4 are
    returned as-is, one object entry per sample.

    Args:
        data: Sequence of samples, each exposing at least five positional fields.

    Returns:
        Tuple ``(signal_tensor, target_tensor, knowledge_tensor, field_3, field_4)``.
    """
    samples = np.array(data, dtype=object)

    tensors = []
    for field in (0, 1, 2):
        stacked = np.stack(samples[:, field])
        # Insert the singleton axis at dim 1 before handing the array to torch.
        tensors.append(torch.from_numpy(np.expand_dims(stacked, 1)))
    signal_tensor, target_tensor, knowledge_tensor = tensors

    return signal_tensor, target_tensor, knowledge_tensor, samples[:, 3], samples[:, 4]

# Evaluation keeps the stored order and includes the final partial batch, so
# every test record contributes to the reported metrics (drop_last=True would
# silently leave out up to 15 records).
test_dataloader = DataLoader(ECGDataset(test_data), batch_size=16, shuffle=False, drop_last=False, collate_fn=__collate_fn)

In [None]:
# Test model
model = ECGAutoencoder(2)
model.to(device)
# map_location keeps the load working when the checkpoint was saved on a
# different device type than the one used in this session.
model.load_state_dict(torch.load(os.path.join(TRAINED_MODEL_PATH, "ECGAutoencoder-2-mit-china-qt-prefered-leads-fs-400-60epoch-CT1.pt"), map_location=device))
model.eval()

# Running element-wise sum of the per-batch confusion matrices.
# NOTE(review): the table printed below treats the four entries as
# (TP, TN, FP, FN) - confirm against CustomStatisticCalculation.
confusion_matrix = (0, 0, 0, 0)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    # Evaluate on the held-out test loader (not the training loader), so the
    # reported metrics describe unseen data.
    for index, batch in enumerate(test_dataloader):
        print(f"Test loader {index+1}/{len(test_dataloader)}", end="\r")
        signal_tensor, target_tensor, knowledge_tensor, origin_r_peak, annotation = batch

        input_marged = torch.cat((signal_tensor, knowledge_tensor), dim=1).float()
        input_marged = input_marged.to(device)

        # Drop the singleton channel axis: (batch, 1, length) -> (batch, length).
        predicted_r_peacks = model(input_marged)
        predicted_r_peacks = predicted_r_peacks.reshape(predicted_r_peacks.shape[0], predicted_r_peacks.shape[2])

        # Convert raw scores into a 0/1 peak mask per signal.
        predicted_r_peacks = prediction_post_processing(predicted_r_peacks.detach().cpu().numpy())

        targets = target_tensor.detach().cpu().numpy()
        targets = targets.reshape(targets.shape[0], targets.shape[2])

        # NOTE(review): the trailing 10 is presumably a matching tolerance in
        # samples - confirm against calculate_batch_statistic.
        local_accuracies, local_precisions, local_recalls, local_f1_scores, matrix = np_stat.calculate_batch_statistic(predicted_r_peacks, targets, SIGNAL_LENGTH, 10)
        confusion_matrix = tuple(matrix[i] + confusion_matrix[i] for i in range(len(confusion_matrix)))

# Print the results
overall_accuracy, overall_precision, overall_recall, overall_f1 = np_stat.calculate_statistic(confusion_matrix[0], confusion_matrix[1], confusion_matrix[2], confusion_matrix[3])

print("Accuracy:", overall_accuracy)
print("Precision:", overall_precision)
print("Recall:", overall_recall)
print("F1 Score:", overall_f1)

confusion_matrix_table = [["", "Actual Positive", "Actual Negative"],
                          ["Predicted Positive", confusion_matrix[0], confusion_matrix[2]],
                          ["Predicted Negative", confusion_matrix[3], confusion_matrix[1]]]

print(tabulate(confusion_matrix_table, tablefmt="grid"))