In [1]:
from lib import utils, model, dataset, training, codes

import os
import h5py

from glob import glob
from sklearn.model_selection import train_test_split
import torch
import numpy as np
import pandas as pd
import json

import os

from lib import utils, model, dataset, training, codes

In [2]:
# Experiment configuration.
# NOTE: the training cell dumps every attribute of `config` (plus the class lists it
# adds) to JSON and hashes it into `cfg_hash`, which names the run's .cfg/.csv/.pt
# files — adding, renaming or (depending on utils.generate_dict_hash) reordering
# attributes may change the run id.
config = utils.CFG({})
config.seed = 43  # passed to utils.set_seed and used as random_state for every split
config.cache_path = 'cache'  # directory holding the cached dataframe ('df.csv')
config.data_path = '/ayb/vol1/datasets/ecg_datasets/physionet.org/files/challenge-2021/1.0.3/training/ptb-xl'  # PTB-XL (PhysioNet Challenge 2021 copy)
config.logs_path = 'results'  # <cfg_hash>.cfg (JSON config) and <cfg_hash>.csv (per-epoch history)
config.models_path = 'results'  # <cfg_hash>.pt (best checkpoint by valid mean ROC-AUC)
config.test_size = 0.2  # fraction of records held out as the test split
config.valid_size = 0.25  # fraction of the remaining records used for validation (0.25 * 0.8 = 0.2 overall)
config.min_class_count = 200  # threshold for utils.calsses_from_captions (presumably minimum class frequency — confirm)
config.batch_size = 256
config.num_workers = 12  # DataLoader worker processes
config.ecg_sr = 128  # presumably target sampling rate in Hz — TODO confirm in dataset.CLIP_ECG_Dataset
config.window = 1280  # presumably window length in samples (10 s at 128 Hz?) — TODO confirm
config.text_embedding_size = 768  # presumably the text encoder's output size — confirm in model.CLIPModel
config.projection_dim = 256  # presumably size of the shared projection space — confirm
config.dropout = 0.15
config.pretrained = True
config.text_encoder_model = 'emilyalsentzer/Bio_ClinicalBERT'
config.text_tokenizer = 'emilyalsentzer/Bio_ClinicalBERT'
config.temperature = 10.0  # presumably the CLIP logit temperature — confirm usage in model/training
config.head_lr = 0.0001  # learning rate for net.image_projection and net.text_projection
config.image_encoder_lr = 0.001  # learning rate for net.image_encoder (presumably the ECG encoder)
config.device = 'cuda:2'  # hard-coded GPU index; adjust for the machine running the notebook
config.epochs = 30
config.max_length = 200  # presumably the tokenizer's max caption length — confirm
config.ecg_encoder_channels = [32, 32, 64, 64, 128, 128, 256, 256]
config.ecg_encoder_kernels = [7, 7, 5, 5, 3, 3, 3, 3]  # same length as ecg_encoder_channels (presumably one kernel per layer)
config.ecg_linear_size = 512
config.ecg_embedding_size = 256
config.ecg_channels = 12  # presumably the number of ECG leads
config.excluded_classes = ['abnormal QRS']  # removed from train_classes (and hence from valid/test classes)
config.train_required_classes = ['sinus rhythm']  # never picked as a zero-shot class
config.zero_shot_classes_size = 0.4  # fraction of test classes held out as zero-shot (removed from training)
config.ecg_encoder_model = 'ecglib_resnet1d50'  # NOTE(review): the training cell's output downloaded 12_leads_resnet1d18_* weights — confirm which encoder CLIPModel builds

In [3]:
# Presumably set to avoid HuggingFace tokenizers' parallelism-after-fork warning
# when DataLoader workers are spawned later.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

utils.set_seed(config.seed)
df = utils.get_data_cached(config.data_path, codes.DECODE_DICT, config.cache_path + '/df.csv')

# Record-level split: carve out the test set first, then validation from the remainder.
train, test = train_test_split(df, test_size=config.test_size, random_state=config.seed)
train, valid = train_test_split(train, test_size=config.valid_size, random_state=config.seed)

# Candidate classes found in each split's captions.
train_classes = utils.calsses_from_captions(train['label'].values, threshold=config.min_class_count)
valid_classes = utils.calsses_from_captions(valid['label'].values, threshold=config.min_class_count)
test_classes = utils.calsses_from_captions(test['label'].values, threshold=config.min_class_count)

# Drop explicitly excluded classes; valid/test may only use classes also present in train.
train_classes = [class_ for class_ in train_classes if class_ not in config.excluded_classes]
valid_classes = [class_ for class_ in valid_classes if class_ in train_classes]
test_classes = [class_ for class_ in test_classes if class_ in train_classes]

# Protect "required" classes from being drawn into the zero-shot set.
excluded = list()
for class_ in config.train_required_classes:
    if class_ in test_classes:
        test_classes.remove(class_)
        excluded.append(class_)

# Split the *class names* themselves: a fraction becomes zero-shot (unseen in training).
test_classes, zero_shot_classes = train_test_split(test_classes, test_size=config.zero_shot_classes_size, random_state=config.seed)

test_classes += excluded

# Zero-shot classes must not be trained on or used for model selection.
train_classes = [class_ for class_ in train_classes if class_ not in zero_shot_classes]
valid_classes = [class_ for class_ in valid_classes if class_ not in zero_shot_classes]

train_classes = sorted(train_classes)
valid_classes = sorted(valid_classes)
# Bug fix: this previously read `sorted(valid_classes)`, silently replacing the
# test class list with the validation one.
test_classes = sorted(test_classes)

print('Train/valid/test classes counts:', len(train_classes), len(valid_classes), len(test_classes), len(zero_shot_classes))

# Strip zero-shot classes out of the training captions so they are never seen.
train['label'] = utils.remove_classes(zero_shot_classes, train['label'].to_list())


Train/valid/test classes counts: 21 7 7 4


In [4]:
config.train_classes = train_classes
config.valid_classes = valid_classes
config.test_classes = test_classes
config.zero_shot_classes = zero_shot_classes

train_ds = dataset.CLIP_ECG_Dataset(train, config)
valid_ds = dataset.CLIP_ECG_Dataset(valid, config)

train_dl = torch.utils.data.DataLoader(train_ds, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=True)
valid_dl = torch.utils.data.DataLoader(valid_ds, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=False)

net = model.CLIPModel(config)
net = net.to(config.device)
# Only these three parameter groups are handed to Adam; any other parameters of
# `net` (e.g. a text encoder, if CLIPModel owns one) are not updated.
params = [
    {"params": net.image_encoder.parameters(), "lr": config.image_encoder_lr},
    {"params": net.image_projection.parameters(), "lr": config.head_lr},
    {"params": net.text_projection.parameters(), "lr": config.head_lr},
]

optimizer = torch.optim.Adam(params)

# The run id is a hash of every config attribute, including the class lists set above.
cfg = {k: v for k, v in config.__dict__.items() if not k.startswith('__')}
cfg_hash = utils.generate_dict_hash(cfg)

# The output directories may not exist on a fresh checkout; create them so the
# open()/to_csv()/torch.save() calls below do not raise FileNotFoundError.
os.makedirs(config.logs_path, exist_ok=True)
os.makedirs(config.models_path, exist_ok=True)

with open(f'{config.logs_path}/{cfg_hash}.cfg', 'w') as fp:
    json.dump(cfg, fp)

history = list()
best_valid_score = 0.0
for epoch in range(config.epochs):
    print(f"Epoch: {epoch + 1}")
    hrow = dict()
    hrow['epoch'] = epoch  # 0-based, unlike the 1-based printout above
    net.train()
    train_loss_meter, train_accuracy_meter = training.train_epoch(net, train_dl, optimizer, train_classes, config)
    hrow['train_loss'] = train_loss_meter.avg

    # Re-score the training set. The prints below require valid_epoch's metrics to
    # contain 'mean_rocaucs' and 'mean_praucs' (prefixed here with 'train_'/'valid_').
    metrics = training.valid_epoch(net, train_dl, train_classes, config)
    hrow.update({f'train_{key}': val for key, val in metrics.items()})
    print('Train:', hrow['train_mean_rocaucs'], hrow['train_mean_praucs'])

    metrics = training.valid_epoch(net, valid_dl, valid_classes, config)
    hrow.update({f'valid_{key}': val for key, val in metrics.items()})
    print('Valid:', hrow['valid_mean_rocaucs'], hrow['valid_mean_praucs'])

    # Rewrite the whole history each epoch so the CSV stays current if the run dies.
    history.append(hrow)
    pd.DataFrame(history).to_csv(config.logs_path + f'/{cfg_hash}.csv', index=False)

    # Keep only the checkpoint with the best validation mean ROC-AUC.
    # NOTE: a NaN score never compares greater, so no checkpoint is written for it.
    if hrow['valid_mean_rocaucs'] > best_valid_score:
        best_valid_score = hrow['valid_mean_rocaucs']
        torch.save(net.state_dict(), config.models_path + f'/{cfg_hash}.pt')

100%|███████████████████████████████████| 13101/13101 [00:01<00:00, 6742.83it/s]
100%|█████████████████████████████████████| 4368/4368 [00:00<00:00, 8211.92it/s]
Downloading: "https://github.com/ispras/EcgLib/releases/download/v1.1.0/12_leads_resnet1d18_1AVB.pt" to /home/kyegorov/.cache/torch/hub/checkpoints/12_leads_resnet1d18_1AVB_1_1_0.pt


Epoch: 1


  0%|                                                    | 0/52 [00:04<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x1 and 1024x256)

In [None]:
# Rebuild the architecture and load the best checkpoint written by the training loop.
net = model.CLIPModel(config)
# map_location='cpu': by default torch.load restores each tensor onto the device it
# was saved from (config.device at training time), which fails if that GPU index is
# unavailable. Load on CPU, then move the whole model to config.device.
net.load_state_dict(torch.load(config.models_path + f'/{cfg_hash}.pt', map_location='cpu', weights_only=True))
# Inference only from here on; eval() is harmless if valid_epoch also sets the mode.
net.to(config.device).eval()

In [None]:
# Score the held-out PTB-XL test records on config.test_classes.
test_ds = dataset.CLIP_ECG_Dataset(test, config)
# Evaluation only: no shuffling, matching valid_dl in the training cell, so batch
# composition (and any batch-dependent statistic) is deterministic.
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=False)
metrics = training.valid_epoch(net, test_dl, config.test_classes, config)
config.test_metrics = metrics

In [None]:
# Zero-shot evaluation on the same test records, scored on the classes that were
# stripped from the training captions; the result is also kept on `config`.
metrics = config.zero_shot_test_metrics = training.valid_epoch(net, test_dl, config.zero_shot_classes, config)

In [None]:
# Display the config, which now also carries test_metrics and zero_shot_test_metrics.
config

In [None]:
def remove_nonprimary_code(x):
    """Flatten '+'-joined code groups and drop codes whose value is in [200, 500).

    ``x`` is one row of ``df['AHA_Code'].str.split(';')``: a list of strings, each
    of which may join several codes with '+' (e.g. ``['1', '22+50']``). Codes are
    stripped of whitespace, empty tokens (e.g. from a trailing ';' or '+') are
    skipped instead of crashing ``int('')``, and the kept codes are de-duplicated
    in first-seen order.

    A missing AHA_Code reaches this function as a float NaN (pandas string methods
    return NaN for missing values); it yields ``[]`` rather than a TypeError.

    Args:
        x: iterable of code-group strings, or NaN/None for a missing value.

    Returns:
        list[str]: the retained codes as stripped strings.
    """
    if x is None or isinstance(x, float):
        return []
    kept = []
    for group in x:
        for token in group.split('+'):
            code = token.strip()
            if not code:
                continue
            value = int(code)
            # NOTE(review): codes 200-499 are treated as non-primary; presumably these
            # are AHA modifier/secondary statements in SPH's code.csv — confirm.
            if (value < 200 or value >= 500) and code not in kept:
                kept.append(code)
    return kept

def codes_to_caption(codes, descriptions=None):
    """Turn a list of numeric codes into a lower-cased, comma-separated caption.

    Args:
        codes: iterable of codes (str or int) convertible with ``int``. Note this
            parameter name shadows the imported ``codes`` module inside this
            function only.
        descriptions: mapping from int code to description text. Defaults to the
            notebook-global ``description_dict`` (built from SPH ``code.csv``),
            looked up at call time, so ``Series.apply(codes_to_caption)`` keeps
            working; pass it explicitly to avoid the hidden global dependency.

    Returns:
        str: e.g. ``'normal ecg, sinus bradycardia'``; ``''`` for no codes.

    Raises:
        KeyError: if a code has no entry in ``descriptions``.
    """
    if descriptions is None:
        descriptions = description_dict
    return ', '.join(descriptions[int(code)].lower() for code in codes)

In [None]:
# External SPH dataset: build one caption per record from its primary AHA codes.
data_path = '/ayb/vol1/datasets/ecg_datasets/SPH'
ecg_files = sorted(glob(f'{data_path}/records/*.h5'))
# codes_to_caption reads this global, so it must exist before captions are built.
description_dict = (
    pd.read_csv(f'{data_path}/code.csv')
    .set_index('Code')['Description']
    .to_dict()
)
df = pd.read_csv(f'{data_path}/metadata.csv').assign(
    primary_codes=lambda frame: frame['AHA_Code'].str.split(';').apply(remove_nonprimary_code),
    label=lambda frame: frame['primary_codes'].apply(codes_to_caption),
    ecg_file=lambda frame: frame['ECG_ID'].apply(lambda x: f'{data_path}/records/{x}.h5'),
)

In [None]:
# Keep only what CLIP_ECG_Dataset is given below: the record path and its caption.
df = df.loc[:, ['ecg_file', 'label']]

In [None]:
# Classes frequent enough in SPH, split by whether they are in config.train_classes.
# sorted(): iteration order of a set of strings varies between interpreter runs
# (str hash randomization), so sort to keep these lists — and the metrics/config
# derived from them — reproducible, as is done for the PTB-XL class lists.
config.exp2_classes = utils.calsses_from_captions(df['label'].values, threshold=config.min_class_count)
config.exp2_trained_classes = sorted(set(config.exp2_classes) & set(config.train_classes))
config.exp2_untrained_classes = sorted(set(config.exp2_classes) - set(config.train_classes))

In [None]:
# Evaluate on the whole SPH dataset, scored on classes the model was trained on.
test_ds = dataset.CLIP_ECG_Dataset(df, config)
# Evaluation only: no shuffling (matches valid_dl), so batching is deterministic.
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=False)
metrics = training.valid_epoch(net, test_dl, config.exp2_trained_classes, config)
config.exp2_metrics_trained = metrics

In [None]:
# Same SPH loader, now scored on SPH classes absent from config.train_classes;
# the result is also kept on `config`.
metrics = config.exp2_metrics_untrained = training.valid_epoch(net, test_dl, config.exp2_untrained_classes, config)

In [None]:
# Display the config again, now including the exp2 (SPH) class lists and metrics.
config

In [None]:
# SPH metrics on classes that are also in config.train_classes.
config.exp2_metrics_trained

In [None]:
# SPH metrics on classes not in config.train_classes.
config.exp2_metrics_untrained