In [5]:
from lib import utils, model, dataset, training, codes

import os
import h5py

from glob import glob
from sklearn.model_selection import train_test_split
import torch
import numpy as np
import pandas as pd
import json

import os

from lib import utils, model, dataset, training, codes
import lib.datasets

In [6]:
# NOTE(review): hard-coded metric snapshots. None of these four names is read anywhere
# in this notebook; the evaluation cells store live results on config.test_metrics,
# config.zero_shot_test_metrics, config.exp2_metrics_trained and
# config.exp2_metrics_untrained instead. Several class names below (e.g.
# 'myocardial infarction') are not among the train_classes shown in the final config
# printout, so these presumably come from an earlier run with a different
# dataset/config — TODO record which run produced them.
# Key layout: 'accuracy', then '<class>_rocauc' / '<class>_prauc' per class, then
# 'mean_rocaucs' / 'mean_praucs'.

# Presumably the counterpart of config.test_metrics (test split, test classes).
test_metrics = {'accuracy': 0.7648809523809523, 
                'atrial fibrillation_rocauc': 0.9171844215707146, 
                'atrial fibrillation_prauc': 0.4361265410523662, 
                'incomplete right bundle branch block_rocauc': 0.9088251470467772, 
                'incomplete right bundle branch block_prauc': 0.4297674746792139, 
                'left anterior fascicular block_rocauc': 0.9268139069260974, 
                'left anterior fascicular block_prauc': 0.42671961405195336, 
                'left ventricular hypertrophy_rocauc': 0.7587642655965638, 
                'left ventricular hypertrophy_prauc': 0.2757720380334504, 
                'myocardial infarction_rocauc': 0.840214852022878, 
                'myocardial infarction_prauc': 0.618194706110927, 
                'sinus rhythm_rocauc': 0.844971891867679, 
                'sinus rhythm_prauc': 0.9638187216373598, 
                'st depression_rocauc': 0.8141247154000085, 
                'st depression_prauc': 0.18970414424421614, 
                'mean_rocaucs': 0.8586998857758169, 
                'mean_praucs': 0.47715760568706955}

# Presumably the counterpart of config.zero_shot_test_metrics (test split, zero-shot classes).
zero_shot_test_metrics = {'accuracy': 0.14308608058608058, 
                          'left axis deviation_rocauc': 0.5401977190506834, 
                          'left axis deviation_prauc': 0.26244650701746625, 
                          'ventricular ectopics_rocauc': 0.7736850775542983, 
                          'ventricular ectopics_prauc': 0.14137467752900137, 
                          'myocardial ischemia_rocauc': 0.7947370443020241, 
                          'myocardial ischemia_prauc': 0.23642724160072665, 
                          't wave abnormal_rocauc': 0.6003195771799039, 
                          't wave abnormal_prauc': 0.14246834225638647, 
                          'mean_rocaucs': 0.6772348545217273, 
                          'mean_praucs': 0.1956791921008952}

# Presumably the counterpart of config.exp2_metrics_trained (second dataset, classes seen in training).
exp2_metrics_trained = {'accuracy': 0.09355840124175398, 
                        'sinus arrhythmia_rocauc': 0.5895783964056902, 
                        'sinus arrhythmia_prauc': 0.07774905749756673, 
                        'left ventricular hypertrophy_rocauc': 0.8941801477243011, 
                        'left ventricular hypertrophy_prauc': 0.09694352683924748, 
                        'sinus tachycardia_rocauc': 0.9082690054453708, 
                        'sinus tachycardia_prauc': 0.5379829291033716, 
                        'atrial fibrillation_rocauc': 0.9616839712795083, 
                        'atrial fibrillation_prauc': 0.4379333718137916, 
                        'sinus bradycardia_rocauc': 0.71066942466592, 
                        'sinus bradycardia_prauc': 0.21737268705116214, 
                        'mean_rocaucs': 0.812876189104158, 
                        'mean_praucs': 0.27359631446102795}

# Presumably the counterpart of config.exp2_metrics_untrained (second dataset, classes not seen in training).
exp2_metrics_untrained = {'accuracy': 0.5838571982925883, 
                          'incomplete right bundle-branch block_rocauc': 0.8064123614532502, 
                          'incomplete right bundle-branch block_prauc': 0.19526101242705962, 
                          'st deviation_rocauc': 0.716575594805385, 
                          'st deviation_prauc': 0.2400055776428945, 
                          't-wave abnormality_rocauc': 0.6992063931745451, 
                          't-wave abnormality_prauc': 0.17219488821487497, 
                          'right bundle-branch block_rocauc': 0.8781902764200318, 
                          'right bundle-branch block_prauc': 0.648628640444118, 
                          'prolonged pr interval_rocauc': 0.6579848060170331, 
                          'prolonged pr interval_prauc': 0.019533055742954115, 
                          'atrial premature complex(es)_rocauc': 0.6521667804330289, 
                          'atrial premature complex(es)_prauc': 0.0532194289301259, 
                          'normal ecg_rocauc': 0.7021679286919714, 
                          'normal ecg_prauc': 0.6681558074606322, 
                          'st deviation with t-wave change_rocauc': 0.7112457722285048, 
                          'st deviation with t-wave change_prauc': 0.09995219189713328, 
                          'ventricular premature complex(es)_rocauc': 0.7305275710112804, 
                          'ventricular premature complex(es)_prauc': 0.153588974502212, 
                          'low voltage_rocauc': 0.46869746319861133, 
                          'low voltage_prauc': 0.012805849968587496, 
                          'mean_rocaucs': 0.7023174947433641, 
                          'mean_praucs': 0.22633454272305925}



In [7]:
# Experiment configuration. Each setting is applied with setattr, one by one and in
# exactly this order, which is what plain `config.x = ...` assignment does. The order
# matters: config.__dict__ is later dumped to JSON and hashed to name the run's files.
config = utils.CFG({})
for _key, _value in {
    # output / data locations
    'seed': 43,
    'cache_path': 'cache',
    'data_path': '/ayb/vol1/datasets/ecg_datasets/physionet.org/files/challenge-2021/1.0.3/training/ptb-xl',
    'logs_path': 'results',
    'models_path': 'results',
    # row-split fractions for train_test_split and the class-frequency threshold
    'test_size': 0.2,
    'valid_size': 0.25,
    'min_class_count': 200,
    # DataLoader settings
    'batch_size': 256,
    'num_workers': 12,
    'ecg_sr': 128,
    'window': 1280,
    'text_embedding_size': 768,
    'projection_dim': 256,
    'dropout': 0.15,
    'pretrained': True,
    'text_encoder_model': 'emilyalsentzer/Bio_ClinicalBERT',
    'text_tokenizer': 'emilyalsentzer/Bio_ClinicalBERT',
    'temperature': 10.0,
    # per-parameter-group learning rates used by the Adam optimizer
    'head_lr': 0.0001,
    'image_encoder_lr': 0.001,
    'device': 'cuda:0',
    'epochs': 30,
    'max_length': 200,
    # ECG encoder
    'ecg_encoder_model': 'ECGConvEncoder',
    'ecg_encoder_channels': [32, 32, 64, 64, 128, 128, 256, 256],
    'ecg_encoder_kernels': [7, 7, 5, 5, 3, 3, 3, 3],
    'ecg_linear_size': 512,
    'ecg_embedding_size': 512,
    'ecg_channels': 12,
    # class selection: dropped classes, classes never made zero-shot, zero-shot fraction
    'excluded_classes': ['abnormal QRS'],
    'train_required_classes': ['normal ecg'],
    'zero_shot_classes_size': 0.4,
}.items():
    setattr(config, _key, _value)
# Don't leave the loop variables behind in the notebook namespace.
del _key, _value

In [8]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

utils.set_seed(config.seed)
df = lib.dataset.code15.load_df()

# Row split: hold out the test rows first, then carve validation out of the remainder.
train, test = train_test_split(df, test_size=config.test_size, random_state=config.seed)
train, valid = train_test_split(train, test_size=config.valid_size, random_state=config.seed)

# Candidate classes per split (captions passing the min_class_count threshold).
train_classes = utils.calsses_from_captions(train['label'].values, threshold=config.min_class_count)
valid_classes = utils.calsses_from_captions(valid['label'].values, threshold=config.min_class_count)
test_classes = utils.calsses_from_captions(test['label'].values, threshold=config.min_class_count)

# Drop explicitly excluded classes; valid/test only keep classes that also exist in train.
train_classes = [class_ for class_ in train_classes if class_ not in config.excluded_classes]
valid_classes = [class_ for class_ in valid_classes if class_ in train_classes]
test_classes = [class_ for class_ in test_classes if class_ in train_classes]

# Classes in train_required_classes must never become zero-shot: pull them out before
# the class-level split and add them back to test_classes afterwards.
excluded = list()
for class_ in config.train_required_classes:
    if class_ in test_classes:
        test_classes.remove(class_)
        excluded.append(class_)

# Class-level split: a zero_shot_classes_size fraction of the test classes becomes zero-shot.
test_classes, zero_shot_classes = train_test_split(test_classes, test_size=config.zero_shot_classes_size, random_state=config.seed)

test_classes += excluded

# Zero-shot classes are removed from the class lists used for training and model selection.
train_classes = [class_ for class_ in train_classes if class_ not in zero_shot_classes]
valid_classes = [class_ for class_ in valid_classes if class_ not in zero_shot_classes]

train_classes = sorted(train_classes)
valid_classes = sorted(valid_classes)
# Sort the test list itself: taking valid_classes here would replace the test classes
# selected above with the validation class list.
test_classes = sorted(test_classes)

# Invariant: no zero-shot class may appear in any of the other class lists.
assert not set(zero_shot_classes) & (set(train_classes) | set(valid_classes) | set(test_classes)), \
    'zero-shot classes leaked into the train/valid/test class lists'

print('Train/valid/test classes counts:', len(train_classes), len(valid_classes), len(test_classes), len(zero_shot_classes))

# Remove the zero-shot classes from the training captions (via utils.remove_classes).
train['label'] = utils.remove_classes(zero_shot_classes, train['label'].to_list())


Train/valid/test classes counts: 4 4 4 3


In [None]:
config.train_classes = train_classes
config.valid_classes = valid_classes
config.test_classes = test_classes
config.zero_shot_classes = zero_shot_classes

train_ds = dataset.CLIP_ECG_Dataset(train, config)
valid_ds = dataset.CLIP_ECG_Dataset(valid, config)

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=True)
valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=False)

net = model.CLIPModel(config)
net = net.to(config.device)
# Per-group learning rates. Only the image (ECG) encoder and the two projection heads are
# handed to Adam; parameters outside these groups are never updated by this optimizer.
params = [
    {"params": net.image_encoder.parameters(), "lr": config.image_encoder_lr},
    {"params": net.image_projection.parameters(), "lr": config.head_lr},
    {"params": net.text_projection.parameters(), "lr": config.head_lr},
]

optimizer = torch.optim.Adam(params)

# Snapshot every non-dunder config attribute (including the class lists set above);
# its hash names this run's .cfg, .csv history and .pt checkpoint.
cfg = {k: v for k, v in config.__dict__.items() if not k.startswith('__')}
cfg_hash = utils.generate_dict_hash(cfg)

# The artifact folders may not exist on a fresh checkout; create them up front so the
# config/history/checkpoint writes below don't fail with FileNotFoundError mid-run.
os.makedirs(config.logs_path, exist_ok=True)
os.makedirs(config.models_path, exist_ok=True)

with open(f'{config.logs_path}/{cfg_hash}.cfg', 'w') as fp:
    json.dump(cfg, fp)

history = list()
best_valid_score = 0.0
for epoch in range(config.epochs):
    print(f"Epoch: {epoch + 1}")
    hrow = dict()
    hrow['epoch'] = epoch
    net.train()
    train_loss_meter, train_accuracy_meter = training.train_epoch(net, train_dl, optimizer, train_classes, config)
    hrow['train_loss'] = train_loss_meter.avg

    # Extra full pass over train_dl to compute train-set metrics for the history.
    metrics = training.valid_epoch(net, train_dl, train_classes, config)
    hrow.update({f'train_{key}': val for key, val in metrics.items()})
    print('Train:', hrow['train_mean_rocaucs'], hrow['train_mean_praucs'])

    metrics = training.valid_epoch(net, valid_dl, valid_classes, config)
    hrow.update({f'valid_{key}': val for key, val in metrics.items()})
    print('Valid:', hrow['valid_mean_rocaucs'], hrow['valid_mean_praucs'])

    # Rewrite the full history every epoch so the CSV is usable if training is interrupted.
    history.append(hrow)
    pd.DataFrame(history).to_csv(config.logs_path + f'/{cfg_hash}.csv', index=False)

    # Keep only the checkpoint with the best validation mean ROC-AUC seen so far.
    if hrow['valid_mean_rocaucs'] > best_valid_score:
        best_valid_score = hrow['valid_mean_rocaucs']
        torch.save(net.state_dict(), config.models_path + f'/{cfg_hash}.pt')

  5%|████████▏                                                                                                                                                              | 10132/207467 [01:51<36:08, 91.02it/s]

In [None]:
# Rebuild the model and load the checkpoint saved for this cfg_hash by the training loop.
net = model.CLIPModel(config)
net.load_state_dict(
    torch.load(f'{config.models_path}/{cfg_hash}.pt', weights_only=True)
)
net.to(config.device)

In [None]:
# Evaluate on the held-out test rows, restricted to config.test_classes.
# shuffle=False, as for valid_dl: `test` is already randomly ordered by train_test_split,
# and a fixed loader order keeps results reproducible if any metric in valid_epoch
# depends on which samples share a batch (not visible here — confirm in training.py).
test_ds = dataset.CLIP_ECG_Dataset(test, config)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=False)
metrics = training.valid_epoch(net, test_dl, config.test_classes, config)
config.test_metrics = metrics

In [None]:
# Zero-shot evaluation: same test loader, scored only on the classes withheld from training.
metrics = config.zero_shot_test_metrics = training.valid_epoch(
    net, test_dl, config.zero_shot_classes, config
)

In [None]:
# Display the config, which now also carries test_metrics and zero_shot_test_metrics.
config

In [None]:
def remove_nonprimary_code(x):
    """Flatten '+'-joined code groups and keep only codes outside the 200-499 range.

    Args:
        x: iterable of strings, each holding one or more integer codes joined by '+'.

    Returns:
        List of the kept code strings (codes < 200 or >= 500), deduplicated and in
        order of first appearance.

    Raises:
        ValueError: if a code is not an integer literal.
    """
    # dict.fromkeys keeps insertion order, so it deduplicates without reordering.
    kept = (
        code
        for group in x
        for code in group.split('+')
        if not 200 <= int(code) < 500
    )
    return list(dict.fromkeys(kept))

def codes_to_caption(codes, descriptions=None):
    """Turn a list of code strings into a lower-cased, comma-separated caption.

    Args:
        codes: iterable of codes; each is converted with int() before lookup.
        descriptions: mapping from int code to description text. When omitted, the
            notebook-level ``description_dict`` is used (looked up at call time), so
            existing ``Series.apply(codes_to_caption)`` calls keep working.

    Returns:
        The lower-cased descriptions joined with ', ' ('' for no codes).

    Raises:
        KeyError: if a code is missing from a dict mapping.
    """
    # Only touch the global when no explicit mapping is supplied; this removes the
    # hidden dependency on a variable defined in a later cell.
    lookup = description_dict if descriptions is None else descriptions
    return ', '.join(lookup[int(code)].lower() for code in codes)

In [None]:
# Second-experiment data: metadata, code descriptions and per-record HDF5 file paths.
data_path = '/ayb/vol1/datasets/ecg_datasets/SPH'
# Code -> description text; read by codes_to_caption when building captions below.
description_dict = pd.read_csv(f'{data_path}/code.csv').set_index('Code')['Description'].to_dict()

df = pd.read_csv(f'{data_path}/metadata.csv')
# AHA_Code holds ';'-separated code groups: keep codes outside 200-499 and render them as a caption.
df['primary_codes'] = df['AHA_Code'].str.split(';').apply(remove_nonprimary_code)
df['label'] = df['primary_codes'].apply(codes_to_caption)
df['ecg_file'] = df['ECG_ID'].apply(lambda record_id: f'{data_path}/records/{record_id}.h5')

# NOTE(review): ecg_files is not read by any later cell.
ecg_files = sorted(glob(f'{data_path}/records/*.h5'))

In [None]:
# Keep only the file-path and caption columns. `df` is rebound, so the intermediate
# columns (AHA_Code, primary_codes, ...) are no longer reachable through it.
df = df[['ecg_file', 'label']]

In [None]:
# Classes in this dataset's captions that pass the min_class_count threshold, split by
# whether they are in train_classes. sorted() makes the order deterministic (iteration
# order of a set of strings changes between interpreter runs because of hash
# randomization) and matches the sorted train/valid/test class lists.
config.exp2_classes = utils.calsses_from_captions(df['label'].values, threshold=config.min_class_count)
config.exp2_trained_classes = sorted(set(config.exp2_classes) & set(config.train_classes))
config.exp2_untrained_classes = sorted(set(config.exp2_classes) - set(config.train_classes))

In [None]:
# Second experiment: score the whole frame on the classes that are also in train_classes.
test_ds = dataset.CLIP_ECG_Dataset(df, config)
test_dl = torch.utils.data.DataLoader(
    test_ds,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=True,
)
metrics = config.exp2_metrics_trained = training.valid_epoch(
    net, test_dl, config.exp2_trained_classes, config
)

In [None]:
# Same loader, now scored on the classes that are not in train_classes.
# Called through the `training` module imported at the top, like every other cell,
# instead of reaching it via the `lib` package name.
metrics = training.valid_epoch(net, test_dl, config.exp2_untrained_classes, config)
config.exp2_metrics_untrained = metrics

In [23]:
# Display the complete config, including the exp2_* class lists and metrics set above.
config

seed: 43
cache_path: cache
data_path: /ayb/vol1/datasets/ecg_datasets/physionet.org/files/challenge-2021/1.0.3/training/ptb-xl
logs_path: results
models_path: results
test_size: 0.2
valid_size: 0.25
min_class_count: 200
batch_size: 256
num_workers: 12
ecg_sr: 128
window: 1280
text_embedding_size: 768
projection_dim: 256
dropout: 0.15
pretrained: True
text_encoder_model: emilyalsentzer/Bio_ClinicalBERT
text_tokenizer: emilyalsentzer/Bio_ClinicalBERT
temperature: 10.0
head_lr: 0.0001
image_encoder_lr: 0.001
device: cuda:0
epochs: 30
max_length: 200
ecg_encoder_model: ECGConvEncoder
ecg_encoder_channels: [32, 32, 64, 64, 128, 128, 256, 256]
ecg_encoder_kernels: [7, 7, 5, 5, 3, 3, 3, 3]
ecg_linear_size: 512
ecg_embedding_size: 512
ecg_channels: 12
excluded_classes: ['abnormal QRS']
train_required_classes: ['normal ecg']
zero_shot_classes_size: 0.4
train_classes: ['1st degree AV block', 'normal ecg', 'right bundle branch block', 'sinus bradycardia']
valid_classes: ['1st degree AV block', 'n

In [None]:
# Metrics on the second dataset for classes that are also in train_classes.
config.exp2_metrics_trained

In [None]:
# Metrics on the second dataset for classes absent from train_classes.
config.exp2_metrics_untrained