In [1]:
import tarfile
import timeit
import random
import math
import re
import itertools
import datetime 
import os

from operator import itemgetter
from scipy.signal import cheby2, resample, sosfilt
from scipy import signal
from scipy.io import loadmat, savemat
from tensorly import tensor as tensor_tly
from tensorly import norm, dot
from tensorly.decomposition import tucker
from tensorly.tucker_tensor import tucker_to_tensor
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras.activations import relu
from tensorflow.keras.models import load_model, Sequential
from tensorflow.keras.layers import Dense, Dropout, Conv2D, MaxPooling2D, Flatten, BatchNormalization, Activation
from tensorflow.keras.regularizers import l2, l1_l2

import numpy as np
import pandas as pd
import ecg_plot
import seaborn as sns
import matplotlib.pyplot as plt
import logging
%matplotlib inline

2022-03-28 19:10:27.053903: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-03-28 19:10:27.053919: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [2]:
class Database(object):
  def __init__(self, path, leads):
    self.path = path
    self.leads = leads
    self.freq = 0
    self.headers_path = []
    self.recordings_path = []
    

  def extract_from_drive(self):
    if not os.path.exists(self.path):
      with tarfile.open(self.path) as zip_file:
        zip_file.extractall()
    

  def load_hea_file(self, i):
    with open(self.headers_path[i], 'r') as f:
      hea_file = f.read()
    return hea_file


  def get_frequency(self):
    header = self.load_hea_file(0)
    for i, l in enumerate(header.split('\n')):
        if i==0:
            try:
                self.freq = float(l.split(' ')[2])
            except:
                pass
        else:
            break
    print(f'Frequency: {self.freq}')


  def __str__(self):
    return f'Database(folder={self.folder}, path_drive={self.path_drive}, freq={self.freq},  leads={self.leads}, headers_paths={len(self.headers_paths)}, recordings_paths={len(self.recordings_paths)})'


  def load_paths(self):
    for f in sorted(os.listdir(self.path)):
      root, extension = os.path.splitext(f)
      if not root.startswith('.') and extension=='.hea':
        header_db_file = os.path.join(self.path, root + '.hea')
        recording_db_file = os.path.join(self.path, root + '.mat')
        if os.path.isfile(header_db_file) and os.path.isfile(recording_db_file):
          self.headers_path.append(header_db_file)
          self.recordings_path.append(recording_db_file)
    print(f'Found {len(self.headers_path)} recordings in {self.path}.')
    

class Diagnostic(object):
  diagnostics = []


  @classmethod
  def append_diagnostic(cls, diagnostic):
    cls.diagnostics.append(diagnostic)


  @classmethod
  def get_diagnostics(cls):
    return cls.diagnostics


  def __init__(self, diag_name, abbrev, code):
    self.name = diag_name
    self.abbrev = abbrev
    self.code = code
    
    

class DiagnosticDatabase(object):
  databases = []

  @classmethod
  def append_database(cls, diag_db):
    cls.databases.append(diag_db)   

    
  @classmethod
  def get_df_recordings(cls):  
    total_recs = []
    for diag_db in cls.databases:
        recs = [rec.__dict__ for rec in diag_db.recordings_diag]
        recs = [dict(rec, db=diag_db.db.path.split('/')[1], diagnostic=diag_db.diagnostic.abbrev) for rec in recs]
        total_recs.append(recs)
    total_recs = [rec for db in total_recs for rec in db]
    dataframe = pd.DataFrame(total_recs)
    return dataframe
    

  def __init__(self, diag_origin, db_origin):
    self.diagnostic = diag_origin
    self.db = db_origin
    self.headers_diag_path = []
    self.recordings_diag = []


  def get_labels(self, header):
    labels = list()
    for l in header.split('\n'):
      if l.startswith('#Dx'):
        try:
          entries = l.split(': ')[1].split(',')
          for entry in entries:
            labels.append(entry.strip())
        except:
          pass
    return labels


  def get_leads(self,header):
    leads = list()
    for i, l in enumerate(header.split('\n')):
      entries = l.split(' ')
      if i==0:
        num_leads = int(entries[1])
      elif i<=num_leads:
        leads.append(entries[-1])
      else:
        break
    return tuple(leads)


  def choose_leads(self, recording, header, leads):
    num_leads = len(leads)
    num_samples = np.shape(recording)[1]
    chosen_recording = np.zeros((num_leads, num_samples), recording.dtype)
    available_leads = self.get_leads(header)
    for i, lead in enumerate(leads):
      if lead in available_leads:
        j = available_leads.index(lead)
        chosen_recording[i, :] = recording[j, :]
    return chosen_recording


  def plot_ecg(self, index):
    ecg_plot.plot(self.rsp_cut_recordings_diag[index]/1000, sample_rate=self.db.freq/2, title='')
    ecg_plot.show()

    
class Record():
  def __init__(self, filename, inf, sup, data):
    self.filename = filename
    self.inf = inf
    self.sup = sup
    self.data = data

In [4]:
DiagnosticDatabase.databases = []
Diagnostic.diagnostics = []

af_diag = Diagnostic('Atrial Fibrilation', 'AF', '164889003')
sr_diag = Diagnostic('Sinus Rhythm', 'SR', '426783006')
Diagnostic.append_diagnostic(af_diag)
Diagnostic.append_diagnostic(sr_diag)

path_folder = '../databases/'
available_databases = os.listdir(path_folder)
#try:
    #available_databases.remove('WFDB_PTB')
    #available_databases.remove('WFDB_StPetersburg')
#except ValueError:
#    pass
available_databases = list(map(lambda db: path_folder + db, available_databases))
leads = ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')

# Databases
for path in available_databases:
  folder = re.search('databases/(.*)', path).group(1)
  db = Database(path, leads)
  db.load_paths()
  db.get_frequency()
    
  # Diagnostics
  for diag in Diagnostic.get_diagnostics():
    diag_db = DiagnosticDatabase(diag, db)
    
    # Recordings
    for i, header_path in enumerate(diag_db.db.headers_path):
      header = diag_db.db.load_hea_file(i)
      labels = diag_db.get_labels(header)
      if diag_db.diagnostic.code in labels:
        # Get record
        rec_file = header_path.replace('hea', 'mat')
        recording = loadmat(rec_file)['val']
        recording = diag_db.choose_leads(recording, header, diag_db.db.leads)
        recording = np.array(recording, dtype=np.float64)
        
        # Filtering
        sos = signal.cheby2(12, 20, [0.35, 70], 
                            'bandpass', 
                            fs= diag_db.db.freq, 
                            output='sos')
        for lead in range(0, recording.shape[0]):
            to_filt = np.array(recording[lead, :], dtype=np.float64)
            filtered = signal.sosfilt(sos, to_filt)
            recording[lead, :] = np.array(filtered, dtype=np.float64)

        # Resample record to 250Hz
        new_freq = 250
        time_rec = len(recording[0])/diag_db.db.freq
        n_samples = int(time_rec*new_freq)
        recording = resample(recording, n_samples, axis=1)

        # Cut record in 250 samples
        interval = 250
        size_rec = recording.shape[1]
        samples_rec = math.floor(size_rec/interval)
        for i in range(0, samples_rec):
          inf = i*interval
          sup = ((i+1)*interval)
          recording_interval = np.array(list(map(lambda lead: lead[inf:sup], recording)))
          rec_filename = rec_file.split('/')[-1]
          record = Record(rec_filename, 
                          inf, 
                          sup, 
                          recording_interval)
          diag_db.recordings_diag.append(record)
        diag_db.headers_diag_path.append(header_path)
    print(f' - Found {len(diag_db.headers_diag_path)} recordings for {diag_db.diagnostic.abbrev}.')
    print(f' - Unattached {len(diag_db.recordings_diag)} intervals.')
    DiagnosticDatabase.append_database(diag_db)

Found 10344 recordings in ../databases/WFDB_Ga.
Frequency: 500.0
 - Found 570 recordings for AF.
 - Unattached 5690 intervals.
 - Found 1752 recordings for SR.
 - Unattached 17435 intervals.
Found 74 recordings in ../databases/WFDB_StPetersburg.
Frequency: 257.0
 - Found 2 recordings for AF.
 - Unattached 3600 intervals.
 - Found 0 recordings for SR.
 - Unattached 0 intervals.
Found 516 recordings in ../databases/WFDB_PTB.
Frequency: 1000.0
 - Found 15 recordings for AF.
 - Unattached 1341 intervals.
 - Found 80 recordings for SR.
 - Unattached 9435 intervals.
Found 34905 recordings in ../databases/WFDB_Ningbo.
Frequency: 500.0
 - Found 0 recordings for AF.
 - Unattached 0 intervals.
 - Found 6299 recordings for SR.
 - Unattached 62990 intervals.
Found 3453 recordings in ../databases/WFDB_CPSC2018_2.
Frequency: 500.0
 - Found 153 recordings for AF.
 - Unattached 2325 intervals.
 - Found 4 recordings for SR.
 - Unattached 61 intervals.
Found 6877 recordings in ../databases/WFDB_CPSC2018

### Statistics Database

In [5]:
df = DiagnosticDatabase.get_df_recordings()
%reset_selective -f "^DiagnosticDatabase$"
df = df.sample(frac=1)
print(f"TOTAL: {df.shape[0]} rows")
print(f"AF: {str(df[df.diagnostic == 'AF'].shape[0])} rows")
print(f"SR: {str(df[df.diagnostic == 'SR'].shape[0])} rows")
df.groupby(['diagnostic', 'db']).size()

TOTAL: 357589 rows
AF: 59132 rows
SR: 298457 rows


diagnostic  db       
AF          databases     59132
SR          databases    298457
dtype: int64

### Get balanced class database

In [6]:
# df_b = df.groupby('diagnostic')
# df_b = df_b.apply(lambda x: x.sample(df_b.size().min()))
# df_b = df_b.droplevel(level=0)
# df_b

Unnamed: 0,filename,inf,sup,data,db,diagnostic
334276,JS02349.mat,1750,2000,"[[-7.196979998059459, -17.227663931933506, -30...",databases,AF
113817,A4158.mat,1250,1500,"[[-72.71772528082523, -69.42428411249715, -69....",databases,AF
23730,I0049.mat,151250,151500,"[[-177.97506264040737, -176.25399870561907, -1...",databases,AF
24053,I0049.mat,232000,232250,"[[20.167173296569008, 22.462589312722237, 22.9...",databases,AF
24574,I0049.mat,362250,362500,"[[-88.54836200954301, -84.21898884299921, -81....",databases,AF
...,...,...,...,...,...,...
308698,HR18939.mat,2250,2500,"[[-40.16797066449827, -42.409213808409845, -44...",databases,SR
306371,HR18654.mat,500,750,"[[-51.48686351908005, -70.19114117157837, -84....",databases,SR
158424,HR00915.mat,1250,1500,"[[-36.7280623625614, -26.21199411477894, -24.8...",databases,SR
254229,HR12315.mat,0,250,"[[12.960585205688535, -41.18561151880605, -64....",databases,SR


### Persist dataframe

In [7]:
filename = str(datetime.datetime.now())
filename = filename.replace(' ', '_').replace(':', '_').replace('.', '_')
filename = '_'.join(filename.split('_')[:-1])
df_b.to_pickle(f'{filename}.pkl')  