In [2]:
import matplotlib.pyplot as plt
import numpy as np
from pymongo import MongoClient

In [2]:
def connect(host="mydb", port=27017):
    """Open a MongoDB client and return the ECG collection.

    Args:
        host: hostname of the MongoDB server (defaults to "mydb")
        port: port of the MongoDB server (defaults to 27017)

    Returns:
        The ``ecg_data`` collection of the ``zib`` database.
    """
    cli = MongoClient(host=host, port=port)
    # Callers receive the collection itself, not the client or the database.
    return cli.zib.ecg_data

## Get and split data

The original dataset (PTB-XL, https://pubmed.ncbi.nlm.nih.gov/32451379/) defines ten stratified folds. It is recommended to use folds 1-8 for training, fold 9 for validation, and fold 10 for testing, because folds 9 and 10 have the highest confidence of diagnosis (i.e., the most reliable labels).
See also https://pubmed.ncbi.nlm.nih.gov/32903191/

In [4]:
def _fetch_fold(db, fold_filter, projection):
    """Return all non-"OTHER" records whose ``strat_fold`` matches ``fold_filter``."""
    query = {"rhythm_diag": {"$ne": "OTHER"}, "strat_fold": fold_filter}
    return list(db.find(query, projection=projection))


def get_data(db):
    """Load the ECG records and split them by stratified fold.

    Records whose ``rhythm_diag`` is "OTHER" are left out. Folds 1-8 form the
    training split, fold 9 the validation split and fold 10 the test split.
    As a side effect, the split sizes are printed and a histogram of all
    labels is drawn.

    Args:
        db: MongoDB collection

    Returns:
        tuple: ``(X_train, X_test, X_val), (y_train, y_test, y_val)``; note
        that test comes before validation. ``X_*`` are arrays built from the
        ``data`` field, ``y_*`` are integer class arrays (SR -> 0, AFIB -> 1).

    Raises:
        KeyError: if a record carries a rhythm label other than "SR" or "AFIB".
    """
    select = {"_id": 0, "ecg_id": 1, "rhythm_diag": 1, "data": 1, "strat_fold": 1}
    # use strat folds 1-8 for training
    data_train = _fetch_fold(db, {"$lte": 8}, select)
    # use strat fold 9 for validation
    data_val = _fetch_fold(db, {"$eq": 9}, select)
    # use strat fold 10 for testing
    data_test = _fetch_fold(db, {"$eq": 10}, select)

    X_train, X_test, X_val = (
        np.array([d["data"] for d in split])
        for split in (data_train, data_test, data_val)
    )
    label_train, label_test, label_val = (
        [d["rhythm_diag"] for d in split]
        for split in (data_train, data_test, data_val)
    )

    print(f"Size of training data: {X_train.shape}")
    print(f"Size of validation data: {X_val.shape}")
    print(f"Size of testing data: {X_test.shape}")

    # NOTE(review): assumes the last axis of ``data`` is the time axis -- confirm.
    print(f"number of samples per time series: {X_train.shape[-1]}")

    # "OTHER" has no class index because the queries above exclude it.
    subs = {"SR": 0, "AFIB": 1}
    print(f"Categories: {subs}")
    y_train, y_test, y_val = (
        np.array([subs[label] for label in labels])
        for labels in (label_train, label_test, label_val)
    )

    # Class balance over all three splits combined.
    _, ax = plt.subplots()
    ax.hist(label_train + label_test + label_val)
    ax.set(
        title="Rhythm labels across all splits",
        xlabel="rhythm_diag",
        ylabel="number of records",
    )

    return (X_train, X_test, X_val), (y_train, y_test, y_val)