In [None]:
import os
import sys
import h5py
import argparse
import numpy as np
from collections import Counter
import xml.etree.ElementTree as et 

# Keras imports
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils.vis_utils import model_to_dot
from tensorflow.keras.layers import Input, Dense, Dropout, AlphaDropout, BatchNormalization, Activation

# ML4CVD Imports
from ml4cvd.plots import plot_ecg
from ml4cvd.arguments import parse_args
from ml4cvd.tensor_writer_ukbb import write_tensors
from ml4cvd.recipes import train_multimodal_multitask

# IPython imports
from IPython.display import Image

In [None]:
def _to_float_or_false(s, sentinel=0):
    try:
        return float(s)
    except ValueError:
        return sentinel

# Running tallies over the scanned files: number of files, sample rates,
# resolutions and unparsable files.
stats = Counter()
xml_files = '/mnt/disks/ecg-rest-xml/'

# Child tags of MeasurementTable whose values are printed for each file.
ECG_TABLE_TAGS = ['RAmplitude'] #, 'SAmplitude']

for f in os.listdir(xml_files):
    if not f.endswith('.xml'):
        continue
    # Only files whose name contains '_20205_' are scanned
    # (presumably the field id of the resting ECG export -- TODO confirm).
    if '_20205_' not in f:
        continue
    try:
        # os.path.join does not rely on xml_files ending with a slash.
        root = et.parse(os.path.join(xml_files, f)).getroot()
    except et.ParseError:
        # Count malformed files instead of letting one of them abort the scan.
        stats['parse_error'] += 1
        continue
    for c in root.findall("./RestingECGMeasurements/MeasurementTable"):
        for cc in c:
            if cc.tag not in ECG_TABLE_TAGS:
                continue
            print(cc.tag, cc.attrib)
            # An element without text has .text == None; print an empty list
            # for it rather than failing on None.strip().
            fields = cc.text.strip().split(',') if cc.text is not None else []
            print(list(map(_to_float_or_false, fields)))

    # str() keeps the tally working when an element has no text (key ends in 'None').
    for c in root.findall("./StripData/SampleRate"):
        stats['Sample Rate_' + str(c.text)] += 1
    for c in root.findall("./StripData/Resolution"):
        stats['Resolution_' + str(c.text)] += 1
    stats['count'] += 1
    # Progress report every 200 successfully parsed files.
    if stats['count'] % 200 == 0:
        print(stats)
print(stats)

In [None]:
# Waveform samples and their FFTs for one ECG file, both keyed by the
# 'lead' attribute of each WaveformData element.
strip_ekgs = {}
ffts = {}
xml_file = '/mnt/disks/data/raw/ecgs/5223300_20205_2_0.xml'

root = et.parse(xml_file).getroot()
for waveform in root.findall("./StripData/WaveformData"):
    # The element text is a comma-separated list of sample values.
    samples = np.array([float(value) for value in waveform.text.strip().split(',')])
    lead = waveform.attrib['lead']
    strip_ekgs[lead] = samples
    ffts[lead] = np.fft.fft(samples)

plot_ecg(strip_ekgs, 'strip')

In [None]:
# Display ./figures/strip_ecg.png inline (presumably written by the plot_ecg(..., 'strip') call -- TODO confirm).
Image('./figures/strip_ecg.png')

In [None]:
# Print the tag and text of every child of each RestingECGMeasurements element
# of the currently loaded `root`.
for measurements in root.findall("./RestingECGMeasurements"):
    for child in measurements:
        print(child.tag, child.text)

# Print each diagnosis text with everything except alphanumerics and spaces removed.
for d in root.findall("./Interpretation/Diagnosis/DiagnosisText"):
    diagnosis = ''.join(filter(lambda ch: ch.isalnum() or ch == ' ', d.text))
    print(diagnosis)

In [None]:
xml_file = '/mnt/disks/data/raw/ecgs/4856206_20205_2_0.xml'

# Start from an empty dict. Writing into the dict filled from the previous ECG
# file would keep that file's samples for any lead this file does not contain,
# and they would then appear in this plot.
strip_ekgs = {}

root = et.parse(xml_file).getroot()
for c in root.findall("./StripData/WaveformData"):
    # The element text is a comma-separated list of sample values, keyed by lead.
    strip_ekgs[c.attrib['lead']] = np.array(list(map(float, c.text.strip().split(','))))

plot_ecg(strip_ekgs, 'strip_bradycardia')

In [None]:
# Display ./figures/strip_bradycardia_ecg.png inline (presumably written by the plot_ecg(..., 'strip_bradycardia') call -- TODO confirm).
Image('./figures/strip_bradycardia_ecg.png')

In [None]:
# Re-read the file named by `xml_file` and print its diagnosis texts,
# keeping only alphanumeric characters and spaces.
root = et.parse(xml_file).getroot()
for statement in root.findall("./Interpretation/Diagnosis/DiagnosisText"):
    kept_chars = [ch for ch in statement.text if ch.isalnum() or ch == ' ']
    diagnosis = ''.join(kept_chars)
    print(diagnosis)

In [None]:
# Replace the process arguments with a recipe command line before calling
# parse_args() (which presumably reads sys.argv -- TODO confirm).
# The run is short: 1 epoch, 20 training steps, batch size 32.
sys.argv = (
    ['train']
    + ['--tensors', '/mnt/disks/ecg-text2/2019-03-30/']
    + ['--input_tensors', 'ecg_rest']
    + ['--output_tensors',
       'ecg_rhythm', 'ecg_normal', 'p-axis', 'p-duration',
       'p-offset', 'p-onset', 'pp-interval', 'pq-interval',
       'q-offset', 'q-onset', 'qrs-num', 'qrs-duration',
       'r-axis', 'ventricular-rate']
    + ['--batch_size', '32']
    + ['--epochs', '1']
    + ['--learning_rate', '0.0001']
    + ['--model_file', '/mnt/ml4cvd/projects/jamesp/data/models/ecg_regresser.hd5']
    + ['--training_steps', '20']
    + ['--inspect_model']
    + ['--id', 'ecg_regresser']
)
args = parse_args()
train_multimodal_multitask(args)

In [None]:
# Display the architecture graph image from the ecg_regresser output directory (presumably produced by the training run -- TODO confirm).
Image('./recipes_output/ecg_regresser/architecture_graph_ecg_regresser.png')

In [None]:
# Display the per-class ROC image for the 'ecg_rhythm' output (file presumably produced by the training run -- TODO confirm).
Image('./recipes_output/ecg_regresser/per_class_roc_ecg_rhythm.png')

In [None]:
# Display the per-class ROC image for the 'ecg_normal' output (file presumably produced by the training run -- TODO confirm).
Image('./recipes_output/ecg_regresser/per_class_roc_ecg_normal.png')

In [None]:
# Display the PPInterval scatter image (file presumably produced by the training run -- TODO confirm).
Image('./recipes_output/ecg_regresser/scatter_PPInterval.png')

In [None]:
# Display the QRSDuration scatter image (file presumably produced by the training run -- TODO confirm).
Image('./recipes_output/ecg_regresser/scatter_QRSDuration.png')