In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from glob import glob

In [2]:
def open_cfg(file):
    with open(file, 'rt') as f:
        data = json.loads(f.read())

    for key in list(data.keys()):
        if type(data[key]) == list:
            data[key] = tuple(data[key])
    return data

def dicts_equal(dict1, dict2):
    keys1 = tuple(sorted(list(dict1.keys())))
    keys2 = tuple(sorted(list(dict2.keys())))
    if keys1 != keys2:
        return False
    for key in keys1:
        if dict1[key] != dict2[key]:
            return False
    return True

In [3]:
MODELS = ['ISIBrno_model', 'CNN_model', 'RNN_model', 'ECGConvEncoder', 'ECGConvEncoder_v2', 'ECGConvEncoder_v3']
TRAIN_DATASETS = [['ptb_xl'], ['ptb_xl', 'ningbo'], ['ptb_xl', 'ningbo', 'georgia']]
EXP2_DATASETS = ['sph', 'code15']

In [4]:
results = sorted(glob('results/*.csv'))
print(len(results))
results = {log: pd.read_csv(log) for log in results}
print(len(results))
results = {key.removeprefix('results/').removesuffix('.csv'): val for key, val in results.items()}
results = {key: {'log': val, 'cfg': open_cfg(f'results/{key}.cfg')} for key, val in results.items() if os.path.isfile(f'results/{key}.cfg')}
print(len(results))
results = {key: val for key, val in results.items() if val['cfg']['ecg_encoder_model'] in MODELS}
print(len(results))

12
12
11
11


In [5]:
for model in MODELS:
    print()
    print(model)
    model_results = {key: val for key, val in results.items() if val['cfg']['ecg_encoder_model'] == model}
    if model == 'ISIBrno_model':
        model_results = {key: val for key, val in model_results.items() if val['cfg']['device'] == 'cuda:0'}
    if model == 'ECGConvEncoder':
        model_results = {key: val for key, val in model_results.items() if val['cfg']['device'] == 'cuda:3'}
    print(len(model_results))
    if len(model_results) > 0:

        for datasets in TRAIN_DATASETS:
            print()
            #print(datasets)
            ds_results = {key: val for key, val in model_results.items() if val['cfg']['train_datasets'] == tuple(datasets)}
            #print(len(ds_results))
            if len(ds_results) > 0:
    
                for train_ds in datasets:
                    vals = [res['cfg']['test_metrics'][train_ds]['mean_rocaucs'] for res in ds_results.values()]
                    print(f'{train_ds} test metrics: {np.mean(vals):.3f}+/-{np.std(vals):.3f}')
        
                
                for train_ds in datasets:
                    vals = [res['cfg']['zero_shot_test_metrics'][train_ds]['mean_rocaucs'] for res in ds_results.values()]
                    print(f'{train_ds} zero-shot metrics: {np.mean(vals):.3f}+/-{np.std(vals):.3f}') 
        
        
                for exp2_ds in EXP2_DATASETS:
                    vals = [res['cfg']['exp2_metrics_trained'][exp2_ds]['mean_rocaucs'] for res in ds_results.values()]
                    print(f'{exp2_ds} trained metrics: {np.mean(vals):.3f}+/-{np.std(vals):.3f}')
        
        
                for exp2_ds in EXP2_DATASETS:
                    vals = [res['cfg']['exp2_metrics_untrained'][exp2_ds]['mean_rocaucs'] for res in ds_results.values()]
                    print(f'{exp2_ds} zero-shot metrics: {np.mean(vals):.3f}+/-{np.std(vals):.3f}')
                #plt.figure(figsize=(30, 5))
                #for result in ds_results.values():
                #    plt.plot(result['log']['valid_mean_rocaucs'])
                #plt.grid()
                #plt.show()


ISIBrno_model
11

ptb_xl test metrics: 0.747+/-0.059
ptb_xl zero-shot metrics: 0.574+/-0.037
sph trained metrics: 0.713+/-0.028
code15 trained metrics: 0.692+/-0.073
sph zero-shot metrics: 0.656+/-0.032
code15 zero-shot metrics: 0.597+/-0.114

ptb_xl test metrics: 0.769+/-0.023
ningbo test metrics: 0.774+/-0.027
ptb_xl zero-shot metrics: 0.638+/-0.048
ningbo zero-shot metrics: 0.546+/-0.039
sph trained metrics: 0.824+/-0.043
code15 trained metrics: 0.851+/-0.051
sph zero-shot metrics: 0.710+/-0.034
code15 zero-shot metrics: 0.727+/-0.098

ptb_xl test metrics: 0.739+/-0.056
ningbo test metrics: 0.705+/-0.100
georgia test metrics: 0.693+/-0.060
ptb_xl zero-shot metrics: 0.582+/-0.021
ningbo zero-shot metrics: 0.534+/-0.025
georgia zero-shot metrics: 0.545+/-0.035
sph trained metrics: 0.785+/-0.049
code15 trained metrics: 0.886+/-0.017
sph zero-shot metrics: 0.647+/-0.086
code15 zero-shot metrics: 0.485+/-0.055

CNN_model
0

RNN_model
0

ECGConvEncoder
0

ECGConvEncoder_v2
0

ECGConvEnco

In [6]:
for key, val in results.items():
    print('*'*80)
    print('Train classes:')
    print(len(val['cfg']['train_classes']))
    print(val['cfg']['train_classes'])
    print()
    print('Valid classes:')
    print(len(val['cfg']['valid_classes']))
    print(val['cfg']['valid_classes'])
    print()
    print('Test classes:')
    print(len(val['cfg']['test_classes']))
    print(val['cfg']['test_classes'])
    print()
    print('Zero-shot classes:')
    print(len(val['cfg']['zero_shot_classes']))
    print(val['cfg']['zero_shot_classes'])
    print()


    print('Exp2 trained classes:')
    print(len(val['cfg']['exp2_trained_classes']))
    print(val['cfg']['exp2_trained_classes'])
    print()    

    print('Exp2 untrained classes:')
    print(len(val['cfg']['exp2_untrained_classes']))
    print(val['cfg']['exp2_untrained_classes'])
    print()  

********************************************************************************
Train classes:
26
('1st degree av block', 'abnormal QRS', 'anterior myocardial infarction', 'atrial flutter', 'bundle branch block', 'complete right bundle branch block', 'early repolarization', 'incomplete right bundle branch block', 'left anterior fascicular block', 'left atrial enlargement', 'left axis deviation', 'left bundle branch block', 'left ventricular hypertrophy', 'low qrs voltages', 'myocardial infarction', 'nonspecific intraventricular conduction disorder', 'nonspecific st t abnormality', 'normal ecg', 'pacing rhythm', 'premature ventricular contractions', 'prolonged pr interval', 'prolonged qt interval', 'qwave abnormal', 'right axis deviation', 'st interval abnormal', 't wave inversion')

Valid classes:
15
('1st degree av block', 'abnormal QRS', 'atrial flutter', 'complete right bundle branch block', 'incomplete right bundle branch block', 'left anterior fascicular block', 'left axis deviat

In [7]:
import lib.datasets.ptb_xl
import lib.datasets.georgia
import lib.datasets.ningbo
import lib.datasets.sph
import lib.datasets.code15

import pandas as pd

In [8]:
def calsses_from_captions(captions, threshold=100):
    all_classes = [name.strip() for caption in captions for name in caption.strip().split(',')]
    counts = pd.Series(all_classes).value_counts()
    return counts
    classes = counts[counts >= threshold].index.to_list()
    classes = sorted(classes)
    return classes

In [73]:
datasets = ['ptb_xl', 'ningbo', 'georgia', 'sph', 'code15']
dfs = list()

for ds in datasets:
    df = lib.datasets.__dict__[ds].load_df()
    dfs.append(df)

In [96]:
classes = list()
for df in dfs:
    classes += list(calsses_from_captions(df['label'], threshold=1).index.values)
classes = sorted(list(set([class_.lower().strip() for class_ in classes])))
classes = pd.DataFrame(classes)
classes = classes.set_index(0)

In [97]:
for ds, df in zip(datasets, dfs):
    classes[f'{ds}_counts'] = calsses_from_captions(df['label'], threshold=1)
classes = classes.fillna(0).astype('int32')

In [98]:
classes

Unnamed: 0_level_0,ptb_xl_counts,ningbo_counts,georgia_counts,sph_counts,code15_counts
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
,0,132,0,0,308004
1st degree av block,797,893,769,0,0
2:1 av block,0,0,0,35,0
2nd degree av block,14,58,23,0,0
abnormal qrs,0,0,0,0,0
...,...,...,...,...,...
ventricular premature beats,0,0,357,0,0
ventricular premature complex(es),0,0,0,1067,0
ventricular trigeminy,20,0,1,0,0
wandering atrial pacemaker,0,0,7,0,0


In [99]:
cols = list()
for ds in datasets:
    classes[f'{ds}_class'] = classes.index.values
    cols.append(f'{ds}_class')
    cols.append(f'{ds}_counts')
classes = classes[cols]

In [100]:
classes = classes.reset_index(drop=True)

In [101]:
#for ds in datasets:
    #mask = classes[f'{ds}_counts'] == 0
    #classes[f'{ds}_counts'] = classes[f'{ds}_counts'].apply(str)
    #classes.loc[mask, f'{ds}_counts'] = ''
    #classes.loc[mask, f'{ds}_class'] = ''

In [102]:
classes.to_csv('data.csv')