# MetaCheX

## Step 1: Data Pre-processing

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import cv2
from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
import tensorflow as tf
from glob import glob
from keras.utils.np_utils import to_categorical   
from imblearn.over_sampling import RandomOverSampler
from sklearn.metrics import roc_curve
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

PATH_TO_DATA_FOLDER = './data'
NIH_IMAGES = 'nih/images'
NIH_METADATA_PATH = 'nih/Data_Entry_2017.csv'
COVID_19_RADIOGRAPHY_IMAGES = 'COVID-19_Radiography_Dataset/images' ## note labels are in the filenames
COVID_CHESTXRAY_IMAGES = 'covid-chestxray-dataset/images'
COVID_CHESTXRAY_METADATA_PATH = 'covid-chestxray-dataset/metadata.csv'

IMAGE_SIZE = 224

If labels.csv exists, read from it.

Otherwise, extract filenames and labels for:
- ChestX-ray14 (NIH) dataset
- COVID-19 Radiography Dataset
- covid-chestxray-dataset

and put in labels.csv

In [2]:
data_path = os.path.join(PATH_TO_DATA_FOLDER, 'data.pkl')
if not os.path.isfile(data_path): ## path does not exist
    df = pd.DataFrame(columns=['image_path', 'label', 'dataset'])

    ## NIH
    full_path = os.path.join(PATH_TO_DATA_FOLDER, NIH_METADATA_PATH)
    df_nih = pd.read_csv(full_path)[['Image Index', 'Finding Labels']]
    df_nih.rename(columns={'Image Index': 'image_path', 'Finding Labels': 'label'}, inplace=True)
    df_nih['label'] = df_nih['label'].str.strip().str.split('|')
    
    # --- Denotes which dataset (train/val/test) the images belong to
    df_nih_splits = pd.DataFrame(columns=['image_path', 'dataset'])
    for dataset_type in ['train', 'val', 'test']:
        sub_df_nih = pd.read_csv(os.path.join('CheXNet_data_split', f'{dataset_type}.csv'), usecols=['Image Index']).rename(columns={'Image Index': 'image_path'})
        sub_df_nih['dataset'] = dataset_type
        df_nih_splits = df_nih_splits.append(sub_df_nih)
    
    df_nih = df_nih.merge(df_nih_splits, how='left')
    ## ----- 
    
    df_nih['image_path'] = PATH_TO_DATA_FOLDER + '/' + NIH_IMAGES + '/' + df_nih['image_path']
    df = df.append(df_nih)

    ## COVID_CHESTXRAY
    full_path = os.path.join(PATH_TO_DATA_FOLDER, COVID_CHESTXRAY_METADATA_PATH)
    df_cc = pd.read_csv(full_path)[['filename', 'finding']]
    df_cc.rename(columns={'filename': 'image_path', 'finding': 'label'}, inplace=True)
    df_cc = df_cc.drop(df_cc[(df_cc['label'] == 'todo') | (df_cc['label'] == 'Unknown')].index).reset_index(drop=True)
    df_cc['label'] = df_cc['label'].str.strip().str.split('/')
    ## Remove the label after 'Pneumonia' that specifies type of pneumonia if given
    for i in range(df_cc.shape[0]):
        label = df_cc.at[i, 'label']
        if 'Pneumonia' in label and len(label) > 1:
            p_idx = label.index('Pneumonia')
            label.pop(p_idx + 1)
            #sort the labels to be in alphabetical order
            df_cc.at[i, 'label'] = sorted(label)
    
    df_cc['image_path'] = PATH_TO_DATA_FOLDER + '/' + COVID_CHESTXRAY_IMAGES + '/' + df_cc['image_path']
    df = df.append(df_cc)

    ## COVID-19 Radiography
    full_path = os.path.join(PATH_TO_DATA_FOLDER, COVID_19_RADIOGRAPHY_IMAGES)
    df_cr = pd.DataFrame(columns=['image_path', 'label'])
    image_lst = sorted(list(glob(f"{full_path}/*"))) ## gets list of all image filepaths
    label_arr = np.array([f[f.rindex('/') + 1:f.rindex('-')] for f in image_lst])
    label_arr = np.where(label_arr == 'COVID', 'COVID-19', label_arr) ## replace COVID with COVID-19 for consistency
    label_arr = np.where(label_arr == 'Viral Pneumonia', 'Pneumonia', label_arr)
    label_arr = np.where(label_arr == 'Normal', 'No Finding', label_arr) ## replace 'Normal' with 'No Finding'
    df_cr['image_path'] = image_lst
    df_cr['label'] = label_arr
    df_cr['label'] = df_cr['label'].str.strip().str.split(pat='.') ## makes each label a list (random sep so that no split on space)
    df = df.append(df_cr)

    df = df.reset_index(drop=True)
    df['label'] = df['label'].sort_values().apply(lambda x: sorted(x)) ## final sort just in case
    df['label_str'] = df['label'].str.join('|')
    df.to_pickle(data_path)
    
## Reads in csv file and adds label_num col (requires label to be str rather than list of str)
df = pd.read_pickle(data_path)

display(df.head(10))

Unnamed: 0,image_path,label,dataset,label_str
0,./data/nih/images/00000001_000.png,[Cardiomegaly],train,Cardiomegaly
1,./data/nih/images/00000001_001.png,"[Cardiomegaly, Emphysema]",train,Cardiomegaly|Emphysema
2,./data/nih/images/00000001_002.png,"[Cardiomegaly, Effusion]",train,Cardiomegaly|Effusion
3,./data/nih/images/00000002_000.png,[No Finding],train,No Finding
4,./data/nih/images/00000003_000.png,[Hernia],train,Hernia
5,./data/nih/images/00000003_001.png,[Hernia],train,Hernia
6,./data/nih/images/00000003_002.png,[Hernia],train,Hernia
7,./data/nih/images/00000003_003.png,"[Hernia, Infiltration]",train,Hernia|Infiltration
8,./data/nih/images/00000003_004.png,[Hernia],train,Hernia
9,./data/nih/images/00000003_005.png,[Hernia],train,Hernia


### Numer of unlabelled or unknown labels in covid-chestxray-dataset

In [10]:
full_path = os.path.join(PATH_TO_DATA_FOLDER, COVID_CHESTXRAY_METADATA_PATH)
df_cc = pd.read_csv(full_path)[['filename', 'finding']]
df_cc.rename(columns={'filename': 'image_path', 'finding': 'label'}, inplace=True)
df_cc = df_cc.reset_index(drop=True)
unknown = df_cc[(df_cc['label'] == 'todo') | (df_cc['label'] == 'Unknown')]
print(len(unknown))

84


In [12]:
full_path = os.path.join(PATH_TO_DATA_FOLDER, COVID_19_RADIOGRAPHY_IMAGES)
df_cr = pd.DataFrame(columns=['image_path', 'label'])
image_lst = sorted(list(glob(f"{full_path}/*"))) ## gets list of all image filepaths
label_arr = np.array([f[f.rindex('/') + 1:f.rindex('-')] for f in image_lst])
np.unique(label_arr)

array(['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia'], dtype='<U15')

Get stats on the data: 
- Number of images with each label (individual)
- Number of unique labels (individual)
- Number of labels total (including combos)

In [3]:
def get_data_stats(df):
    unique_labels_dict = {} ## keys are str
    unique_combos_dict = {} ## keys are tuples of str
    for i in range(df.shape[0]):
        labels = df.at[i, 'label']
        for l in labels:
            if l not in unique_labels_dict:
                unique_labels_dict[l] = 0
            unique_labels_dict[l] += 1

        label_str = df.at[i, 'label_str']
        if label_str not in unique_combos_dict:
            unique_combos_dict[label_str] = 0
        unique_combos_dict[label_str] += 1

    df_label_nums = pd.DataFrame.from_dict(unique_labels_dict, orient='index', columns=['count']).sort_values(by=['count'], ascending=False)
    df_combo_nums = pd.DataFrame.from_dict(unique_combos_dict, orient='index', columns=['count']).sort_values(by=['count'], ascending=False)
    print("Number of total images: ", df_label_nums['count'].sum())
    print("Number of total individual labels (includes 'No Finding'): ", df_label_nums.shape[0])
    print("Number of total label combos (includes individual labels): ", df_combo_nums.shape[0])
    print("****************************")
    print("Number of images with each individual label")
    display(df_label_nums)
    print("\n")
    print("Number of images with each combo label (Bottom 20)")
    display(df_combo_nums.tail(20))
    print("\n")

    ## Get number of labels with number of images in each range
    bins =  np.array([0, 10, 100, 1000, 10000, 100000])
    df_combo_counts = pd.DataFrame(columns=['count interval', 'number of labels'])
    df_combo_counts['count interval'] = ["< 5", "[5, 100)", "[100, 1k)", "[1k, 10k)", ">= 10k"]

    df_combo_counts['number of labels'] = [
                                   df_combo_nums[df_combo_nums['count'] < 5].size,
                                   df_combo_nums[(df_combo_nums['count'] >= 5) & (df_combo_nums['count'] < 1e2)].size,
                                   df_combo_nums[(df_combo_nums['count'] >= 1e2) & (df_combo_nums['count'] < 1e3)].size,
                                   df_combo_nums[(df_combo_nums['count'] >= 1e3) & (df_combo_nums['count'] < 1e4)].size,
                                   df_combo_nums[df_combo_nums['count'] >= 1e4].size
                                  ]
    
    display(df_combo_counts.head())
    
    return unique_labels_dict, df_combo_nums

unique_labels_dict, df_combo_nums = get_data_stats(df)

Number of total images:  164298
Number of total individual labels (includes 'No Finding'):  35
Number of total label combos (includes individual labels):  821
****************************
Number of images with each individual label


Unnamed: 0,count
No Finding,70575
Infiltration,19894
Effusion,13317
Atelectasis,11559
Nodule,6331
Lung_Opacity,6012
Mass,5782
Pneumothorax,5302
Consolidation,4667
COVID-19,4200




Number of images with each combo label (Bottom 20)


Unnamed: 0,count
Consolidation|Effusion|Infiltration|Mass|Pneumothorax,1
Atelectasis|Consolidation|Effusion|Nodule|Pneumothorax,1
Effusion|Emphysema|Infiltration|Mass,1
Atelectasis|Consolidation|Effusion|Emphysema|Mass|Pneumothorax,1
Cardiomegaly|Effusion|Pleural_Thickening|Pneumothorax,1
Atelectasis|Effusion|Fibrosis|Infiltration|Nodule,1
Effusion|Fibrosis|Mass|Pleural_Thickening,1
Cardiomegaly|Consolidation|Effusion|Infiltration|Mass|Pleural_Thickening,1
Atelectasis|Mass|Nodule|Pneumonia,1
Atelectasis|Emphysema|Fibrosis|Infiltration,1






Unnamed: 0,count interval,number of labels
0,< 5,493
1,"[5, 100)",266
2,"[100, 1k)",46
3,"[1k, 10k)",15
4,>= 10k,1


Remove all classes with less than 5 examples

In [5]:
df_combos_exclude = df_combo_nums[df_combo_nums['count'] < 5].reset_index().rename(columns={'index': 'label_str'})
display(df_combos_exclude.head(10))

df_condensed = df[~df['label_str'].isin(df_combos_exclude['label_str'])].reset_index(drop=True)
display(df_condensed.tail(10))

Unnamed: 0,label_str,count
0,Atelectasis|Emphysema|Pleural_Thickening,4
1,Atelectasis|Effusion|Mass|Pneumothorax,4
2,Infiltration|Pleural_Thickening|Pneumonia,4
3,Consolidation|Edema|Effusion|Infiltration|Mass,4
4,Consolidation|Fibrosis|Mass|Pleural_Thickening,4
5,Atelectasis|Consolidation|Edema|Effusion|Infil...,4
6,Emphysema|Fibrosis|Pneumothorax,4
7,Atelectasis|Cardiomegaly|Effusion|Mass,4
8,Atelectasis|Infiltration|Nodule|Pleural_Thicke...,4
9,Atelectasis|Effusion|Hernia,4


Unnamed: 0,image_path,label,dataset,label_str
133308,./data/COVID-19_Radiography_Dataset/images/Vir...,[Pneumonia],,Pneumonia
133309,./data/COVID-19_Radiography_Dataset/images/Vir...,[Pneumonia],,Pneumonia
133310,./data/COVID-19_Radiography_Dataset/images/Vir...,[Pneumonia],,Pneumonia
133311,./data/COVID-19_Radiography_Dataset/images/Vir...,[Pneumonia],,Pneumonia
133312,./data/COVID-19_Radiography_Dataset/images/Vir...,[Pneumonia],,Pneumonia
133313,./data/COVID-19_Radiography_Dataset/images/Vir...,[Pneumonia],,Pneumonia
133314,./data/COVID-19_Radiography_Dataset/images/Vir...,[Pneumonia],,Pneumonia
133315,./data/COVID-19_Radiography_Dataset/images/Vir...,[Pneumonia],,Pneumonia
133316,./data/COVID-19_Radiography_Dataset/images/Vir...,[Pneumonia],,Pneumonia
133317,./data/COVID-19_Radiography_Dataset/images/Vir...,[Pneumonia],,Pneumonia


### Get updated data stats

In [6]:
print("Stats for condensed dataset")
print("---------------------------")
get_data_stats(df_condensed)

Stats for condensed dataset
---------------------------
Number of total images:  160841
Number of total individual labels (includes 'No Finding'):  28
Number of total label combos (includes individual labels):  328
****************************
Number of images with each individual label


Unnamed: 0,count
No Finding,70575
Infiltration,19510
Effusion,12915
Atelectasis,11201
Nodule,6087
Lung_Opacity,6012
Mass,5470
Pneumothorax,5062
Consolidation,4398
COVID-19,4200




Number of images with each combo label (Bottom 20)


Unnamed: 0,count
Cardiomegaly|Effusion|Emphysema|Pneumothorax,5
Pneumonia|Pneumothorax,5
Atelectasis|Consolidation|Effusion|Emphysema,5
Atelectasis|Consolidation|Effusion|Infiltration|Mass,5
Cardiomegaly|Effusion|Infiltration|Pleural_Thickening,5
Effusion|Emphysema|Nodule,5
Emphysema|Mass|Nodule,5
Effusion|Mass|Pleural_Thickening|Pneumothorax,5
Atelectasis|Effusion|Mass|Pleural_Thickening,5
Emphysema|Pneumonia,5






Unnamed: 0,count interval,number of labels
0,< 5,0
1,"[5, 100)",266
2,"[100, 1k)",46
3,"[1k, 10k)",15
4,>= 10k,1


({'Cardiomegaly': 2569,
  'Emphysema': 2330,
  'Effusion': 12915,
  'No Finding': 70575,
  'Hernia': 177,
  'Infiltration': 19510,
  'Mass': 5470,
  'Nodule': 6087,
  'Atelectasis': 11201,
  'Pneumothorax': 5062,
  'Pleural_Thickening': 3119,
  'Fibrosis': 1542,
  'Edema': 2108,
  'Consolidation': 4398,
  'Pneumonia': 3420,
  'COVID-19': 4200,
  'SARS': 16,
  'Pneumocystis': 30,
  'Streptococcus': 22,
  'Klebsiella': 10,
  'Legionella': 10,
  'Varicella': 6,
  'Mycoplasma': 11,
  'Influenza': 5,
  'Tuberculosis': 18,
  'Nocardia': 8,
  'MERS-CoV': 10,
  'Lung_Opacity': 6012},
                                                     count
 No Finding                                          70575
 Infiltration                                         9547
 Lung_Opacity                                         6012
 Atelectasis                                          4215
 Effusion                                             3955
 ...                                                   ...
 Con

### Generate labels

In [7]:
def generate_labels(df, filename, combo=True):
    path = os.path.join(PATH_TO_DATA_FOLDER, filename)
    if not os.path.isfile(path): ## path does not exist
        if combo:
            ## Get combo label (for multiclass classification)
            df['label_num_multi'] = df.groupby(['label_str']).ngroup()

        ## Get binary multi-task labels
        unique_labels = list(unique_labels_dict.keys())
        unique_labels.remove('No Finding')
        unique_labels.sort() ## alphabetical order

        df['label_multitask'] = 0
        df['label_multitask'] = df['label_multitask'].astype('object')
        for i, row in df.iterrows():
            indices = []
            for l in row['label']:
                if l == 'No Finding':
                    break

                idx = unique_labels.index(l)
                indices.append(idx)

            if indices == []:
                df.at[i, 'label_multitask'] = np.zeros(len(unique_labels)).astype(np.uint8)
            else:
                df.at[i, 'label_multitask'] = np.eye(len(unique_labels))[indices].sum(axis=0).astype(np.uint8)

        ## Save to disk
        df.to_pickle(path)
    else:
        df = pd.read_pickle(path)
    
    return df

df_condensed = generate_labels(df_condensed, 'data_condensed.pkl')
display(df_condensed.head(3))
df_condensed['label_multitask'][1]

Unnamed: 0,image_path,label,dataset,label_str,label_num_multi,label_multitask
0,./data/nih/images/00000001_000.png,[Cardiomegaly],train,Cardiomegaly,102,"[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,./data/nih/images/00000001_001.png,"[Cardiomegaly, Emphysema]",train,Cardiomegaly|Emphysema,122,"[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ..."
2,./data/nih/images/00000001_002.png,"[Cardiomegaly, Effusion]",train,Cardiomegaly|Effusion,111,"[0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ..."


array([0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=uint8)

### Dataloader

Image preprocessing

In [14]:
IMAGE_SIZE = 224
def load_and_preprocess_image(path, label):
    """
    path: path to image
    """
    try:
        image = tf.io.read_file(path)
        if path[-3:] == 'png':
            image = tf.io.decode_png(image, channels=3)
        else:
            image = tf.io.decode_jpeg(image, channels=3)
        image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE], method='lanczos3')
        image = image / 255 ## pixels in [0, 255] -- normalize to [0, 1]
    except ValueError:
        print('Error: ', path)
    
    return image, label

Train/val/test split: Remember that the NIH data has to be split according to how it was pre-trained

In [9]:
## Train-val-test splits
def get_ds_splits(df, split=(0.7, 0.1, 0.2)):
    """
    df: df of nih data; columns: 'image_path', 'label', 'dataset', 'label_str', 'label_multitask'
    """
    
    ## Deal with NIH datasplit first
    nih_datasets = []
    for ds_type in ['train', 'val', 'test']:
        df_nih = df[df['dataset'] == ds_type]
        ds = tf.data.Dataset.from_tensor_slices((df_nih['image_path'], 
                                                 df_nih['label_multitask'].to_list()))
        nih_datasets.append(ds)
    
    ## Non-nih data
    df_other = df[df['dataset'].isna()]
    other_count = len(df_other)
    ds = tf.data.Dataset.from_tensor_slices((df_other['image_path'], df_other['label_multitask'].to_list()))
    ds = ds.shuffle(other_count, reshuffle_each_iteration=False)
    train_count, val_count = int(other_count * split[0]), int(other_count * split[1])
    train_ds = ds.take(train_count)
    val_ds = ds.skip(train_count).take(val_count)
    test_ds = ds.skip(train_count + val_count) 
    
    other_datasets = [train_ds, val_ds, test_ds]
    
    full_datasets = []
    for i in range(3):
        ds = other_datasets[i].concatenate(nih_datasets[i])
        full_datasets.append(ds)
    
    return full_datasets

[train_ds, val_ds, test_ds] = get_ds_splits(df_condensed)

2021-11-02 07:13:00.683918: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-11-02 07:13:00.693321: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-11-02 07:13:00.694151: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-11-02 07:13:00.695695: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags

Shuffle and batch data

In [15]:
def shuffle_and_batch(ds, batch_size=32):
    ds = ds.cache()
    ds = ds.shuffle(buffer_size=1000)
    ds = ds.map(load_and_preprocess_image) ## maps the preprocessing step
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
    return ds

train_ds = shuffle_and_batch(train_ds)
val_ds = shuffle_and_batch(val_ds)

Error:  Tensor("args_0:0", shape=(), dtype=string)
Error:  Tensor("args_0:0", shape=(), dtype=string)


In [36]:
next(iter(train_ds))

2021-11-02 06:46:44.721613: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


(<tf.Tensor: shape=(32, 224, 224, 3), dtype=float32, numpy=
 array([[[[ 1.17782066e-02,  1.17782066e-02,  1.17782066e-02],
          [ 1.17904972e-02,  1.17904972e-02,  1.17904972e-02],
          [ 1.17887864e-02,  1.17887864e-02,  1.17887864e-02],
          ...,
          [ 1.19341677e-02,  1.19341677e-02,  1.19341677e-02],
          [ 1.18567599e-02,  1.18567599e-02,  1.18567599e-02],
          [ 1.17710959e-02,  1.17710959e-02,  1.17710959e-02]],
 
         [[ 1.14914663e-02,  1.14914663e-02,  1.14914663e-02],
          [ 1.14024915e-02,  1.14024915e-02,  1.14024915e-02],
          [ 1.13780200e-02,  1.13780200e-02,  1.13780200e-02],
          ...,
          [ 9.20914393e-03,  9.20914393e-03,  9.20914393e-03],
          [ 1.03910547e-02,  1.03910547e-02,  1.03910547e-02],
          [ 1.16990227e-02,  1.16990227e-02,  1.16990227e-02]],
 
         [[ 1.51060512e-02,  1.51060512e-02,  1.51060512e-02],
          [ 1.75908301e-02,  1.75908301e-02,  1.75908301e-02],
          [ 1.73001420

## Step 2: Finetuned CheXNet Baseline

Note: CheXNet = DenseNet121 trained on ChestX-ray14 dataset (multi-task binary classification)

Pre-trained weights: https://github.com/brucechou1983/CheXNet-Keras

In [None]:
def load_chexnet_pretrained(class_names=np.arange(14), weights_path='chexnet_weights.h5', 
                            input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3)):

    img_input = tf.keras.layers.Input(shape=input_shape)
    base_model = tf.keras.applications.densenet.DenseNet121(include_top=False, weights=None, 
                                                            input_tensor=img_input, pooling='avg')


    x = base_model.output
    predictions = tf.keras.layers.Dense(len(class_names), activation="sigmoid", name="predictions")(x)
    model = tf.keras.models.Model(inputs=img_input, outputs=predictions)
    model.load_weights(weights_path)

    return model


def load_chexnet(output_dim):
    """
    output_dim: dimension of output
    """
    
    base_model_old = load_chexnet_pretrained()
    x = base_model_old.layers[-2].output ## remove old prediction layer
    
    ## The prediction head can be more complicated if you want
    predictions = tf.keras.layers.Dense(output_dim, activation='softmax', name='prediction')(x)
    chexnet = tf.keras.models.Model(inputs=base_model_old.inputs, outputs=predictions)
    return chexnet
    
chexnet = load_chexnet(34)
print(chexnet.summary())

### Class balancing -- for finetuned CheXNet baseline and finetuned CheXNet w/ supervised contrastive learning
Data augmentation (minor rotations, flips) and oversampling of minority classes (classes w/ < 1k examples); undersampling majority classes (classes w/ > 10k examples)

In [7]:
## TODO

### Data split -- note that the split on NIH data must coincide with pre-trained split (or else leakage)

In [None]:
## TODO

### Train baseline -- multi-task binary classification

In [None]:
train_x, train_y, val_x, val_y, test_x, test_y = train_val_test_split(balanced_data, balanced_labels)

output_dim = ..
chexnet_ce = load(output_dim)
chexnet_ce.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

batch_size = ..
epochs = ..
chexnet_ce.fit(x=train_x, y=train_y,
              validation_data=(val_x, val_y),
              batch_size=,
              epochs=)

### Evaluate CE baseline

In [None]:
print(chexnet_ce.evaluate(test_x, test_y))