In [6]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from glob import glob

In [7]:
def open_cfg(file, encoding='utf-8'):
    """Load a JSON run config, converting top-level list values to tuples.

    Converting lists to tuples makes those values immutable and lets them
    compare equal to tuples (e.g. ``tuple(datasets)``). Only top-level values
    are converted; lists nested inside other values are left untouched.

    Args:
        file: Path to the JSON config file.
        encoding: Text encoding used to read the file. JSON text is UTF-8 by
            specification, so this defaults to UTF-8 rather than the platform
            locale encoding (which would mangle non-ASCII content on e.g.
            Windows/cp1252 systems).

    Returns:
        dict: The parsed config with top-level lists replaced by tuples.
    """
    with open(file, 'rt', encoding=encoding) as f:
        data = json.load(f)
    # Rebuild rather than mutate while iterating; key order is preserved.
    return {key: tuple(value) if isinstance(value, list) else value
            for key, value in data.items()}

def dicts_equal(dict1, dict2):
    """Return True if both dicts have the same keys and equal values per key.

    Keys are compared through dict key views, which compare like sets, so no
    ordering of keys is needed. Sorting keys would raise TypeError for mixed,
    mutually unorderable key types such as ``1`` and ``'a'``.

    Each value pair is checked with ``!=``; any pair for which ``!=`` is true
    makes the dicts unequal (so, for example, NaN floats never match).

    Args:
        dict1: First mapping.
        dict2: Second mapping.

    Returns:
        bool: True when keys and all corresponding values are equal.
    """
    if dict1.keys() != dict2.keys():
        return False
    return not any(dict1[key] != dict2[key] for key in dict1)

In [8]:
# Encoder architectures to analyse; runs whose cfg['ecg_encoder_model'] is
# not listed here are dropped when the results are loaded.
MODELS = [
    'ISIBrno_model',
    'CNN_model',
    'RNN_model',
    'ECGConvEncoder',
    'ECGConvEncoder_v2',
    'ECGConvEncoder_v3',
]

# Training-dataset combinations to report; each entry is matched against a
# run's cfg['train_datasets'] after converting it to a tuple.
TRAIN_DATASETS = [
    ['ptb_xl'],
    ['ptb_xl', 'ningbo'],
    ['ptb_xl', 'ningbo', 'georgia'],
]

# Dataset keys looked up inside cfg['exp2_metrics_trained'] and
# cfg['exp2_metrics_untrained'].
EXP2_DATASETS = [
    'sph',
    'code15',
]

In [9]:
# Load each run's training log (<RESULTS_DIR>/<run>.csv) together with its
# JSON config (<RESULTS_DIR>/<run>.cfg). The four printed counts are:
# CSV files found, logs read, runs that also have a config, and runs whose
# encoder is listed in MODELS.
RESULTS_DIR = 'results'

log_paths = sorted(glob(os.path.join(RESULTS_DIR, '*.csv')))
print(len(log_paths))

# Derive the run name from the file name itself rather than stripping a
# literal 'results/' prefix: glob returns OS-native separators
# ('results\\run.csv' on Windows), where removeprefix('results/') leaves the
# name untouched and every config lookup below silently misses.
run_logs = {os.path.splitext(os.path.basename(path))[0]: pd.read_csv(path)
            for path in log_paths}
print(len(run_logs))

# Runs without a config file are skipped.
cfg_paths = {run: os.path.join(RESULTS_DIR, f'{run}.cfg') for run in run_logs}
runs_with_cfg = {
    run: {'log': log, 'cfg': open_cfg(cfg_paths[run])}
    for run, log in run_logs.items()
    if os.path.isfile(cfg_paths[run])
}
print(len(runs_with_cfg))

# Final mapping used by later cells: run name -> {'log': DataFrame, 'cfg': dict}.
results = {run: res for run, res in runs_with_cfg.items()
           if res['cfg']['ecg_encoder_model'] in MODELS}
print(len(results))

5
5
4
4


In [10]:
# Only runs launched on this device are kept for the listed models; models
# not listed here are not filtered by device.
# NOTE(review): the device appears to single out one experiment batch per
# model — confirm this is still the intended selection.
DEVICE_FILTER = {'ISIBrno_model': 'cuda:0', 'ECGConvEncoder': 'cuda:3'}

# (cfg section, printed label) pairs, in print order, reported for each
# training dataset and for each exp2 dataset respectively.
TRAIN_DS_SECTIONS = [('test_metrics', 'test'), ('zero_shot_test_metrics', 'zero-shot')]
EXP2_DS_SECTIONS = [('exp2_metrics_trained', 'trained'), ('exp2_metrics_untrained', 'zero-shot')]


def print_metric_summary(runs, section, dataset, label):
    """Print mean +/- std of ``mean_rocaucs`` for one dataset across runs.

    Args:
        runs: Mapping run name -> {'cfg': config dict, ...}.
        section: Config key holding per-dataset metrics (e.g. 'test_metrics').
        dataset: Dataset key inside ``cfg[section]``.
        label: Word printed before 'metrics' (e.g. 'test', 'zero-shot').

    Note: np.std is the population std (ddof=0), so a single run prints
    '+/-0.000' — that means "one run", not "no variance".
    """
    vals = [run['cfg'][section][dataset]['mean_rocaucs'] for run in runs.values()]
    print(f'{dataset} {label} metrics: {np.mean(vals):.3f}+/-{np.std(vals):.3f}')


for model in MODELS:
    print()
    print(model)
    model_results = {key: val for key, val in results.items()
                     if val['cfg']['ecg_encoder_model'] == model}
    if model in DEVICE_FILTER:
        model_results = {key: val for key, val in model_results.items()
                         if val['cfg']['device'] == DEVICE_FILTER[model]}
    print(len(model_results))
    if not model_results:
        continue

    for datasets in TRAIN_DATASETS:
        print()
        # open_cfg stores train_datasets as a tuple, hence the conversion.
        ds_results = {key: val for key, val in model_results.items()
                      if val['cfg']['train_datasets'] == tuple(datasets)}
        if not ds_results:
            continue
        for section, label in TRAIN_DS_SECTIONS:
            for train_ds in datasets:
                print_metric_summary(ds_results, section, train_ds, label)
        for section, label in EXP2_DS_SECTIONS:
            for exp2_ds in EXP2_DATASETS:
                print_metric_summary(ds_results, section, exp2_ds, label)


ISIBrno_model
4

ptb_xl test metrics: 0.785+/-0.034
ptb_xl zero-shot metrics: 0.593+/-0.008
sph trained metrics: 0.727+/-0.031
code15 trained metrics: 0.753+/-0.033
sph zero-shot metrics: 0.671+/-0.015
code15 zero-shot metrics: 0.686+/-0.058

ptb_xl test metrics: 0.775+/-0.000
ningbo test metrics: 0.739+/-0.000
ptb_xl zero-shot metrics: 0.571+/-0.000
ningbo zero-shot metrics: 0.535+/-0.000
sph trained metrics: 0.834+/-0.000
code15 trained metrics: 0.876+/-0.000
sph zero-shot metrics: 0.704+/-0.000
code15 zero-shot metrics: 0.620+/-0.000

ptb_xl test metrics: 0.668+/-0.000
ningbo test metrics: 0.583+/-0.000
georgia test metrics: 0.619+/-0.000
ptb_xl zero-shot metrics: 0.568+/-0.000
ningbo zero-shot metrics: 0.499+/-0.000
georgia zero-shot metrics: 0.562+/-0.000
sph trained metrics: 0.717+/-0.000
code15 trained metrics: 0.861+/-0.000
sph zero-shot metrics: 0.560+/-0.000
code15 zero-shot metrics: 0.430+/-0.000

CNN_model
0

RNN_model
0

ECGConvEncoder
0

ECGConvEncoder_v2
0

ECGConvEncod

In [28]:
# (heading, cfg key) for every class split stored in a run config, in the
# order they are printed.
CLASS_SPLITS = [
    ('Train classes:', 'train_classes'),
    ('Valid classes:', 'valid_classes'),
    ('Test classes:', 'test_classes'),
    ('Zero-shot classes:', 'zero_shot_classes'),
    ('Exp2 trained classes:', 'exp2_trained_classes'),
    ('Exp2 untrained classes:', 'exp2_untrained_classes'),
]

# For every run, print each split's class count and the class names.
for key, val in results.items():
    print('*' * 80)
    for heading, cfg_key in CLASS_SPLITS:
        classes = val['cfg'][cfg_key]
        print(heading)
        print(len(classes))
        print(classes)
        print()

********************************************************************************
Train classes:
22
('1st degree av block', 'anterior myocardial infarction', 'atrial fibrillation', 'complete right bundle branch block', 'left anterior fascicular block', 'left atrial enlargement', 'left axis deviation', 'left bundle branch block', 'myocardial infarction', 'myocardial ischemia', 'nonspecific intraventricular conduction disorder', 'nonspecific st t abnormality', 'normal ecg', 'premature atrial contraction', 'prolonged pr interval', 'qwave abnormal', 'right axis deviation', 's t changes', 'sinus arrhythmia', 'sinus bradycardia', 'sinus tachycardia', 'st depression')

Valid classes:
6
('atrial fibrillation', 'left anterior fascicular block', 'left axis deviation', 'myocardial infarction', 'myocardial ischemia', 'normal ecg')

Test classes:
6
('atrial fibrillation', 'left anterior fascicular block', 'left axis deviation', 'myocardial infarction', 'myocardial ischemia', 'normal ecg')

Zero-shot