In [None]:
import numpy as np
import os
import wfdb
from collections import Counter
import pickle
import random
import sys
from tqdm import tqdm
from scipy.interpolate import interp1d

# Beat-annotation symbol -> beat group ('N', 'S', 'V', 'F' or 'Q').
# Annotations whose symbol is not a key here are skipped during extraction.
# NOTE(review): in the MIT-BIH annotation codes '[' / ']' mark the start / end of a
# ventricular flutter/fibrillation episode and '!' is a flutter wave rather than an
# ordinary beat — confirm they are really meant to be extracted as 'V' windows.
label_group_map = {
    'N': 'N', 'L': 'N', 'R': 'N',
    'V': 'V', '/': 'Q', '!': 'V',
    'A': 'S', 'F': 'F', 'f': 'Q',
    'j': 'N', 'a': 'S', 'E': 'V',
    'J': 'S', 'e': 'N', 'Q': 'Q',
    'S': 'S', '[': 'V', ']': 'V',
}

def _extract_beat_windows(primary_data, secondary_data, ann_samples, ann_symbols,
                          num_samples, group_map):
    """Cut a fixed-length window around every annotation whose symbol is in ``group_map``.

    Args:
        primary_data: 1-D signal of the primary lead.
        secondary_data: 1-D signal of the secondary lead, same length as ``primary_data``.
        ann_samples: annotation sample indices, aligned with ``ann_symbols``.
        ann_symbols: annotation symbols.
        num_samples: window length; each window starts ``num_samples // 2`` samples
            before its annotation, so it is exactly ``num_samples`` long (odd or even).
        group_map: symbol -> group label; annotations with other symbols are skipped.

    Returns:
        Three equal-length lists: ``(num_samples, 2)`` two-lead windows,
        ``(num_samples, 1)`` primary-only windows, and the group label of each window.
    """
    half = num_samples // 2
    signal_len = len(primary_data)
    two_lead, one_lead, groups = [], [], []
    for centre, symbol in zip(ann_samples, ann_symbols):
        if symbol not in group_map:
            continue
        start = centre - half
        end = start + num_samples
        # Drop annotations whose window would run off either end of the recording.
        if start < 0 or end > signal_len:
            continue
        primary_segment = primary_data[start:end]
        two_lead.append(np.column_stack((primary_segment, secondary_data[start:end])))
        one_lead.append(primary_segment[:, np.newaxis])
        groups.append(group_map[symbol])
    return two_lead, one_lead, groups


if __name__ == "__main__":

    path = 'mit-bih-arrhythmia-database-1.0.0'
    save_path = 'data/'
    primary_lead = 'MLII'
    secondary_leads = ['V1', 'V2', 'V4', 'V5']
    num_samples = 200  # Number of samples to extract

    all_data = []
    all_data_single_lead = []
    all_group = []

    with open(os.path.join(path, 'RECORDS'), 'r') as fin:
        all_record_name = fin.read().strip().split('\n')

    for record_name in all_record_name:
        record_path = os.path.join(path, record_name)
        try:
            annotation = wfdb.rdann(record_path, 'atr')
            signals, fields = wfdb.rdsamp(record_path)
        except Exception as exc:  # keep going with the remaining records, but say which one failed
            print('read data failed for record {}: {}'.format(record_name, exc))
            continue

        lead_in_data = fields['sig_name']
        if primary_lead not in lead_in_data:
            print('lead in data: [{}]. primary lead {} not found in {}'.format(lead_in_data, primary_lead, record_name))
            continue

        # Use only the first matching secondary lead so every beat is added once;
        # iterating over all matches would duplicate the single-lead windows and labels.
        secondary_lead = next((lead for lead in secondary_leads if lead in lead_in_data), None)
        if secondary_lead is None:
            print('lead in data: [{}]. no secondary lead of {} found in {}'.format(lead_in_data, secondary_leads, record_name))
            continue

        primary_data = signals[:, lead_in_data.index(primary_lead)]
        secondary_data = signals[:, lead_in_data.index(secondary_lead)]
        two_lead, one_lead, groups = _extract_beat_windows(
            primary_data, secondary_data, annotation.sample, annotation.symbol,
            num_samples, label_group_map)
        all_data.extend(two_lead)
        all_data_single_lead.extend(one_lead)
        all_group.extend(groups)
        print('record_name:{}, leads:{}/{}'.format(record_name, primary_lead, secondary_lead))

    # reshape is a no-op when windows were collected, and keeps the expected rank
    # (N, num_samples, leads) when none were, so downstream shape[2] still works.
    all_data = np.array(all_data).reshape(-1, num_samples, 2)
    all_group = np.array(all_group)
    all_data_single_lead = np.array(all_data_single_lead).reshape(-1, num_samples, 1)
    print(all_data.shape)
    print(all_data_single_lead.shape)
    print(Counter(all_group))
    os.makedirs(save_path, exist_ok=True)
    np.save(os.path.join(save_path, 'mitdb_data.npy'), all_data)
    np.save(os.path.join(save_path, 'mitdb_data_single_lead.npy'), all_data_single_lead)
    np.save(os.path.join(save_path, 'mitdb_group.npy'), all_group)



In [None]:
import tensorflow.keras as keras
from tensorflow.keras import utils
import os
import tensorflow as tf
import datetime
import numpy as np
from src import train_and_evaluate
from importlib import reload
reload(train_and_evaluate)
from src import crtnet_models
reload(crtnet_models)

def label2index(i):
    """Return the integer class index of beat-group label ``i``.

    Indices follow the order N, S, V, F, Q (0..4). Any other label raises KeyError.
    """
    index_of = {label: position for position, label in enumerate(('N', 'S', 'V', 'F', 'Q'))}
    return index_of[i]

def load_and_preprocess_data(path, num_classes, data_filename='mitdb_data.npy',
                             label_filename='mitdb_group.npy'):
    """Load saved beat windows and their group labels, one-hot encoding the labels.

    Args:
        path: directory holding the ``.npy`` files.
        num_classes: width of the one-hot encoding.
        data_filename: file of beat windows. Pass ``'mitdb_data_single_lead.npy'``
            to load the primary-lead-only windows instead of the default file.
        label_filename: file of group-label strings aligned with the windows.

    Returns:
        ``(data, one_hot)``: the loaded window array and the one-hot label matrix
        produced by ``utils.to_categorical``.

    Raises:
        ValueError: if the data and label files hold a different number of entries.
        KeyError: if a label is not one known to ``label2index``.
    """
    data = np.load(os.path.join(path, data_filename))
    label_str = np.load(os.path.join(path, label_filename))
    # A mismatch means the files come from different extraction runs; catch it
    # here rather than as a confusing shape error inside training.
    if len(data) != len(label_str):
        raise ValueError('{} has {} entries but {} has {}'.format(
            data_filename, len(data), label_filename, len(label_str)))
    # Explicit int dtype so an empty label file still yields integer class ids.
    label = np.array([label2index(s) for s in label_str], dtype=np.int64)
    one_hot = utils.to_categorical(label, num_classes=num_classes)
    return data, one_hot

path = 'data/'
# Passed as `classes=` below; presumably names the one-hot columns in
# train_and_evaluate (confirm), so its order must match label2index.
class_names = ['N', 'S', 'V', 'F', 'Q']
num_classes = len(class_names)
SEED = 42

assert [label2index(name) for name in class_names] == list(range(num_classes)), \
    'class_names must list the labels in label2index order'

# Seed numpy and TensorFlow so repeated runs are comparable as far as these RNGs
# are concerned. NOTE(review): train_and_evaluate may use its own randomness — not visible here.
np.random.seed(SEED)
tf.random.set_seed(SEED)

samples, one_hot_encoding_labels = load_and_preprocess_data(path, num_classes)

# Early stopping tracks validation accuracy (and restores the best weights),
# while the learning-rate schedule reacts to validation loss.
stopping = keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True)
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, min_lr=0.00001)

model_methods = [
    crtnet_models.create_crtnet_alternate_vgg1,
   # crtnet_models.create_crtnet_original_vgg1
]

for create_crtnet_method in model_methods:
    train_and_evaluate.train_and_evaluate_model(
        create_crtnet_method,
        samples=samples,
        one_hot_encoding_labels=one_hot_encoding_labels,
        callbacks=[reduce_lr, stopping],
        is_multilabel=False,
        epochs=100,
        folds=None,
        batch_size=128,
        classes=class_names,
        initial_learning_rate=0.0001,
        number_of_leads=samples.shape[2]
    )


