# Select a Subset of Images for Training, Validation and Testing

## Import packages

In [8]:
import pandas as pd
import numpy as np
import math
from os import walk
import os

## User inputs

In [2]:
BASE_PATH = "./labels"

# approximate target number of images for each disease
# if number of images available < NUM_ITEM_PER_CAT for a disease,
# all images associated with that disease will be selected
NUM_ITEM_PER_CAT = 2000

# target number of images for the no finding class
NUM_ITEM_NO_FINDING = 10000

# Approximate percentage of the validation images in 
# the training-validation dataset
PERCENT_VAL = 0.1
N_SPLITS = 3; # no. of training-validation splits

# Name of the training-validation file
OUT_TRAIN_VAL_LIST = 'train_val_B.csv'
# Prefix for the training and validation files 
# The split number will be added to the prefix followed by .csv
OUT_TRAIN_LIST_PREFIX = 'train_B'
OUT_VAL_LIST_PREFIX = 'val_B'

OUT_TEST_LIST = 'test_B.csv'

## Load csv file containing image labels and other details

In [3]:
dat_all = pd.read_csv(BASE_PATH + '/Data_Entry_2017_v2020.csv')
dat_all = dat_all[['Image Index', 'Finding Labels']]
dat_all.rename(columns={'Image Index': 'Image', 'Finding Labels': 'Labels'}, inplace=True)
dat_all.head()

Unnamed: 0,Image,Labels
0,00000001_000.png,Cardiomegaly
1,00000001_001.png,Cardiomegaly|Emphysema
2,00000001_002.png,Cardiomegaly|Effusion
3,00000002_000.png,No Finding
4,00000003_001.png,Hernia


## Get frequency for each disease class

In [4]:
labels = ['No Finding', 'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration', 'Mass', 'Nodule', 
          'Atelectasis', 'Pneumothorax', 'Pleural_Thickening', 'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation']
#labels = ['No Finding', 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
#          'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema',
#          'Fibrosis', 'Pleural_Thickening', 'Hernia']

def label_stats(dat_all, labels):

    label_freq = dict(zip(labels, len(labels)*[0]))

    for label in labels:
        dat_all[label] = (dat_all['Labels'].str.contains(label)).astype('int32')
        label_freq[label] = dat_all[label].sum()

    dat_all['num_disease'] = dat_all[labels].sum(axis=1)
    # Sort the single-class disease labels by frequency from lowest to highest
    label_freq =  {k: v for k, v in sorted(label_freq.items(), key=lambda item: item[1])}
    return dat_all, label_freq

dat_all, label_freq = label_stats(dat_all, labels)

print(label_freq)
print(dat_all)


{'Hernia': 227, 'Pneumonia': 1431, 'Fibrosis': 1686, 'Edema': 2303, 'Emphysema': 2516, 'Cardiomegaly': 2776, 'Pleural_Thickening': 3385, 'Consolidation': 4667, 'Pneumothorax': 5302, 'Mass': 5782, 'Nodule': 6331, 'Atelectasis': 11559, 'Effusion': 13317, 'Infiltration': 19894, 'No Finding': 60361}
                   Image                  Labels  No Finding  Cardiomegaly  \
0       00000001_000.png            Cardiomegaly           0             1   
1       00000001_001.png  Cardiomegaly|Emphysema           0             1   
2       00000001_002.png   Cardiomegaly|Effusion           0             1   
3       00000002_000.png              No Finding           1             0   
4       00000003_001.png                  Hernia           0             0   
...                  ...                     ...         ...           ...   
112115  00030801_001.png          Mass|Pneumonia           0             0   
112116  00030802_000.png              No Finding           1             0   
1

## Select subset of all images

In [5]:
seed = 0
min_count = NUM_ITEM_PER_CAT; # target number of images
dat_all['Selected'] = False
# Select the rows in the order of lowest single disease class frequency to highest
for key, value in label_freq.items():
    if value < min_count:
        # Grab all rows associated with a label if the count associated 
        # with the label is less than min_count
        dat_all.loc[dat_all[key] == 1, 'Selected'] = True
    else:
        # Calculate number of already selected rows with label = key
        dat_all1 = dat_all.loc[(dat_all['Selected']) & (dat_all[key]==1)]
        cur_label_count = dat_all1[key].sum(axis = 0)
        deficit = min_count - cur_label_count
        if deficit > 0:
            # Get the rows with label = key and that have not been selected
            dat_all1 = dat_all.loc[(~dat_all['Selected']) & (dat_all[key]==1)]
            sel_inds = dat_all1.sample(n=deficit, axis=0, random_state=seed).index
            dat_all.loc[sel_inds, 'Selected'] = True

# Select more rows associated with a selected class
min_count2 = NUM_ITEM_NO_FINDING
label2 = 'No Finding'
dat_all1 = dat_all.loc[(dat_all['Selected']) & (dat_all[label2 ]==1)]
cur_label_count = dat_all1[label2].sum(axis = 0)
deficit = min_count2 - cur_label_count
if deficit > 0:
    # Get the rows with label = key and that have not been selected
    dat_all1 = dat_all.loc[(~dat_all['Selected']) & (dat_all[label2]==1)]
    sel_inds = dat_all1.sample(n=deficit, axis=0, random_state=seed).index
    dat_all.loc[sel_inds, 'Selected'] = True
        
dat_selected = dat_all.loc[dat_all['Selected']].copy()
dat_selected.drop(columns=['Selected', 'num_disease'], inplace=True)
dat_selected.reset_index(drop=True, inplace=True)
# Remove rows with string label count <= min_labels_count based on the column 'Labels' instead of 1-hot encoding columns
min_labels_count = 6
labels_count = dat_selected.groupby('Labels')['Labels'].count()
print('After subset select, original minimum string label count = ', labels_count.min())

# Get the string labels with count < min_labels_count
def remove_rows_w_low_label_count(dat, column, min_labels_count):
    labels_count = dat.groupby(column)[column].count()
    labels2del = list(labels_count.index[labels_count.values < min_labels_count])
    ind2del =  dat.index[dat['Labels'].isin(labels2del)]
    return dat.drop(index=ind2del)

dat_selected = remove_rows_w_low_label_count(dat_selected, 'Labels', min_labels_count)

labels_count = dat_selected.groupby('Labels')['Labels'].count()
print('After eliminating rows with low string label count, min string label count = ', labels_count.min())

single_class_counts = dat_selected.iloc[:, 2:-1].sum(axis=0)
print('Single class count: ')
print(single_class_counts)
print('Number of selected rows = ', len(dat_selected))
dat_selected.head()



After subset select, original minimum string label count =  1
After eliminating rows with low string label count, min string label count =  6
Single class count: 
No Finding            10000
Cardiomegaly           1835
Emphysema              1841
Effusion               2651
Hernia                  167
Infiltration           3124
Mass                   1739
Nodule                 1679
Atelectasis            1760
Pneumothorax           1806
Pleural_Thickening     1837
Pneumonia              1225
Fibrosis               1532
Edema                  1838
dtype: int64
Number of selected rows =  23630


Unnamed: 0,Image,Labels,No Finding,Cardiomegaly,Emphysema,Effusion,Hernia,Infiltration,Mass,Nodule,Atelectasis,Pneumothorax,Pleural_Thickening,Pneumonia,Fibrosis,Edema,Consolidation
0,00000001_001.png,Cardiomegaly|Emphysema,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0
1,00000001_002.png,Cardiomegaly|Effusion,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0
2,00000003_001.png,Hernia,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
3,00000003_002.png,Hernia,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
4,00000003_003.png,Hernia|Infiltration,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0


## Output combined training and validation list

In [6]:
train_list0 = pd.read_csv(BASE_PATH + '/train_val_list.txt', header = None, names=['Image'])
dat_train_val = dat_selected.loc[dat_selected['Image'].isin(train_list0['Image'].values)]
dat_train_val.reset_index(drop=True, inplace=True)
labels_count = dat_train_val.groupby('Labels')['Labels'].count()
print('Original train & val data min string label count = ', labels_count.min())
#labels_count.sort_values(ascending=True, inplace=True)
#print(labels_count
dat_train_val = remove_rows_w_low_label_count(dat_train_val, 'Labels', min_labels_count)
dat_train_val.reset_index(drop=True, inplace=True)
labels_count = dat_train_val.groupby('Labels')['Labels'].count()
print('Final train & val data min string label count = ', labels_count.min())
single_class_counts = dat_train_val.iloc[:, 2:-1].sum(axis=0)
print('Single class count for training and validation data: ')
print(single_class_counts)
print('Number of selected training rows = ', len(dat_train_val))
dat_train_val.to_csv(BASE_PATH + '/' + OUT_TRAIN_VAL_LIST, index=False)

Original train & val data min string label count =  1
Final train & val data min string label count =  6
Single class count for training and validation data: 
No Finding            8306
Cardiomegaly          1083
Emphysema             1022
Effusion              1554
Hernia                 101
Infiltration          1812
Mass                  1130
Nodule                1144
Atelectasis           1122
Pneumothorax           806
Pleural_Thickening    1153
Pneumonia              742
Fibrosis              1111
Edema                 1061
dtype: int64
Number of selected training rows =  16944


## Perform stratified split on train+validation list into separate train and validation list

In [7]:
# Split training validation list into training and validation sets
val_frac = PERCENT_VAL; # fraction of data assigned to validation set
n_splits = N_SPLITS

from sklearn.model_selection import StratifiedShuffleSplit
sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=val_frac, random_state=seed)
X = np.random.rand(len(dat_train_val['Labels']))
dat_train = []
dat_val = []
for split_count, (ind1, ind2) in enumerate(sss.split(X, dat_train_val['Labels'])):
    print('\nSplit ', split_count)
    #print('Training set indices = ', ind1)
    dat_train.append(dat_train_val.iloc[ind1])
    #print(dat_train[split_count].head())
    print('Number of training rows = ', len(dat_train[split_count]))
    single_class_counts = dat_train[split_count].iloc[:, 2:-1].sum(axis=0)
    print('Single class count for training data: ')
    print(single_class_counts)
    
    dat_train[split_count].to_csv(BASE_PATH + '/' + OUT_TRAIN_LIST_PREFIX + '_' + str(split_count) + '.csv',
                                 index=False)
    
    #print('Validation set indices = ', ind2)
    dat_val.append(dat_train_val.iloc[ind2])
    print('Number of validation rows = ', len(dat_val[split_count]))
    dat_val[split_count].to_csv(BASE_PATH + '/' + OUT_VAL_LIST_PREFIX  + '_' + str(split_count) + '.csv',
                                index=False)
    #print(dat_val[split_count].head())
    single_class_counts = dat_val[split_count].iloc[:, 2:-1].sum(axis=0)
    print('Single class count for validation data: ')
    print(single_class_counts)


Split  0
Number of training rows =  15249
Single class count for training data: 
No Finding            7475
Cardiomegaly           975
Emphysema              921
Effusion              1395
Hernia                  91
Infiltration          1632
Mass                  1017
Nodule                1029
Atelectasis           1008
Pneumothorax           725
Pleural_Thickening    1036
Pneumonia              668
Fibrosis               999
Edema                  954
dtype: int64
Number of validation rows =  1695
Single class count for validation data: 
No Finding            831
Cardiomegaly          108
Emphysema             101
Effusion              159
Hernia                 10
Infiltration          180
Mass                  113
Nodule                115
Atelectasis           114
Pneumothorax           81
Pleural_Thickening    117
Pneumonia              74
Fibrosis              112
Edema                 107
dtype: int64

Split  1
Number of training rows =  15249
Single class count for training 

# Output test list

In [9]:
test_list0 = pd.read_csv(BASE_PATH + '/test_list.txt', header = None, names=['Image'])
dat_test = dat_selected.loc[dat_selected['Image'].isin(test_list0['Image'].values)]
labels_count = dat_test.groupby('Labels')['Labels'].count()
print('\nTest data min string label count = ', labels_count.min())
dat_test = remove_rows_w_low_label_count(dat_test, 'Labels', min_labels_count)
labels_count = dat_train_val.groupby('Labels')['Labels'].count()
print('Final test data min string label count = ', labels_count.min())
single_class_counts = dat_test.iloc[:, 2:-1].sum(axis=0)
print('Single class count: ')
print(single_class_counts)
print('Number of selected testing rows = ', len(dat_test))
dat_test.to_csv(BASE_PATH + '/' + OUT_TEST_LIST, index=False)


Test data min string label count =  1
Final test data min string label count =  6
Single class count: 
No Finding            1694
Cardiomegaly           630
Emphysema              753
Effusion               857
Hernia                  52
Infiltration          1065
Mass                   423
Nodule                 381
Atelectasis            498
Pneumothorax           874
Pleural_Thickening     536
Pneumonia              411
Fibrosis               327
Edema                  698
dtype: int64
Number of selected testing rows =  6112
