In [1]:
import pandas as pd

import lib.iterstrat
import lib.datasets

In [2]:
class_df = pd.read_csv('classes_for_experiments_v2.csv')
class_df.head()

Unnamed: 0,class_name,ptb_xl,ningbo,georgia,sph,code15,status,rename_to
0,,0,132,0,0,308004,no eval,normal ecg
1,1st degree av block,797,893,769,0,5716,zero-shot,
2,2:1 av block,0,0,0,35,0,no eval,
3,2nd degree av block,14,58,23,0,0,no eval,
4,abnormal qrs,3389,0,0,0,0,train,


In [3]:
def fix_names(class_df, threshold=150):
    datasets = ['ptb_xl', 'ningbo', 'georgia', 'sph', 'code15']
    
    for index, rename_to in class_df['rename_to'].items():
        if not pd.isna(rename_to):
            rename_to_index = class_df[class_df['class_name'] == rename_to]
            assert rename_to_index.shape[0] == 1
            rename_to_index = rename_to_index.index.values[0]
            class_df.loc[rename_to_index, datasets] += class_df.loc[index, datasets]

    class_fixing_dict = class_df[~class_df['rename_to'].isna()].set_index('class_name')['rename_to'].to_dict()

    class_df = class_df[class_df['rename_to'].isna()]
    zeroshot_classes = class_df[class_df['status'] == 'zero-shot']
    class_df = class_df[class_df['status'] != 'zero-shot']

    train_classes_mask = (class_df['ptb_xl'] >= threshold) & (class_df['class_name'] != 'normal ecg')
    train_classes = class_df[train_classes_mask]
    noteval_classes = class_df[~train_classes_mask]
    return train_classes, noteval_classes, zeroshot_classes, class_fixing_dict


train_classes, noteval_classes, zeroshot_classes, class_fixing_dict = fix_names(class_df)

In [4]:
df = lib.datasets.ptb_xl.load_df()

In [5]:
df['dataset'] = 'ptb_xl'

In [6]:
class_fixing_dict

{nan: 'normal ecg',
 'anterior mi': 'anterior myocardial infarction',
 'atrial premature complex(es)': 'atrial premature complexes',
 'bradycardia': 'sinus bradycardia',
 'complete (third-degree)': 'complete heart block',
 'extensive anterior mi': 'anterior myocardial infarction',
 'incomplete right bundle-branch block': 'incomplete right bundle branch block',
 'junctional escape complex(es)': 'junctional escape',
 'junctional premature complex(es)': 'junctional premature complex',
 'left bundle-branch block': 'left bundle branch block',
 'left-axis deviation': 'left axis deviation',
 'low voltage': 'low qrs voltages',
 'right bundle-branch block': 'right bundle branch block',
 'right-axis deviation': 'right axis deviation',
 'second-degree av block': '2nd degree av block',
 'sinus rhythm': 'normal ecg',
 't-wave abnormality': 't wave abnormal',
 'ventricular premature complex(es)': 'ventricular premature beats'}

In [7]:
def prepare_df(df, train_classes, zeroshot_classes, class_fixing_dict):
    def fix_caption(caption):
        if caption == '':
            return 'normal ecg'
        split = caption.lower().strip().split(', ')
        fixed_caption = list()
        for class_ in split:
            class_ = class_.strip()
            if class_ in class_fixing_dict:
                class_ = class_fixing_dict[class_]
            fixed_caption.append(class_)

        if len(fixed_caption) > 1 and 'normal ecg' in fixed_caption:
            fixed_caption.remove('normal ecg')
        return ', '.join(fixed_caption)

    df['fixed_label'] = df['label'].apply(fix_caption)
    for class_ in train_classes['class_name'].values:
        df[class_] = df['fixed_label'].apply(lambda x: class_ in x)

    for class_ in zeroshot_classes['class_name'].values:
        df[class_] = df['fixed_label'].apply(lambda x: class_ in x)   

    return df

df = prepare_df(df, train_classes, zeroshot_classes, class_fixing_dict)

In [16]:
params = dict()
params['TARGETS'] = train_classes['class_name'].to_list()# + zeroshot_classes['class_name'].to_list()
params['SEED'] = 333

In [17]:
valid_mask = lib.iterstrat.short_valid_split(df, params, split=5)

In [18]:
df['split'] = valid_mask

In [19]:
for split in range(5):
    subdf = df[df['split'] == split]
    print(subdf[params['TARGETS']].sum())
    print()
    print()

abnormal qrs                                         673
anterior myocardial infarction                        77
complete right bundle branch block                   327
indeterminate cardiac axis                            22
inferior ischaemia                                    44
left anterior fascicular block                       319
left atrial enlargement                               87
left axis deviation                                 1023
left posterior fascicular block                       39
left ventricular hypertrophy                         449
low qrs voltages                                      39
myocardial infarction                               1091
myocardial ischemia                                  411
nonspecific intraventricular conduction disorder     166
nonspecific st t abnormality                          76
pacing rhythm                                         56
premature atrial contraction                          80
prolonged pr interval          