# Vorbereitung der Trainings-, Validierungs- und Testdaten

Obwohl in der öffentlichen Diskussion meist mehr über die neuronalen Netze, deren konkrete Architekturen, Zielfunktionen und Trainingsalgorithmen geredet wird, ist gerade im "Deep Learning" - wo man mit sehr flexiblen Modellen arbeitet - eine sorgfältige Auswahl, Qualitätskontrolle und Vorverarbeitung der Trainingsdaten *enorm* wichtig. Trainiert (oder testet/validiert) man auf den falschen Daten, können die Algorithmen noch so gut und die Rechenressourcen noch so üppig sein, das Ergebnis wird enttäuschen.

Dementsprechend ist es auch nicht verwunderlich, dass wir zunächst etwas Arbeit in die Zusammenstellung eines guten Datensatzes investieren müssen. Er wird Thoraxröntgenbilder von Patient*innen mit einer COVID-19-Pneumonie oder einer sonstigen (z. B. bakteriellen) Pneumonie sowie unauffällige Befunde enthalten.

Zunächst müssen Sie dafür einige öffentlich zugängliche Datensätze auf Ihren Rechner herunterladen. Diese werden wir in den folgenden Zellen zu einem einzigen Datensatz zusammenführen, mit dem wir anschließend unsere neuronalen Netze trainieren, validieren und testen können.

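Bitte laden Sie die folgenden Datensätze herunter und entpacken Sie sie in einen Unterordner "raw_data" des Ordners, in dem dieses Notebook liegt. Die Namen der entpackten Ordner müssen den Pfaden in der Code-Zelle unten entsprechen (z. B. "raw_data/covid-chestxray-dataset"). Beim Download als ZIP von GitHub wird an den Ordnernamen in der Regel der Branch-Name angehängt (z. B. "-master"); diesen Zusatz müssen Sie dann entfernen: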

### COVID Chest X-Ray Dataset

https://github.com/ieee8023/covid-chestxray-dataset (auf den grünen "Code"-Button klicken und dann "Download ZIP" auswählen)

### Figure1 COVID Chest X-Ray Dataset

https://github.com/agchung/Figure1-COVID-chestxray-dataset (auf den grünen "Code"-Button klicken und dann "Download ZIP" auswählen)

### Actualmed-COVID-chestxray-dataset
https://github.com/agchung/Actualmed-COVID-chestxray-dataset

### COVID19 Radiography Database
https://www.kaggle.com/tawsifurrahman/covid19-radiography-database (benötigt Registrierung bei Kaggle)

### RSNA Pneumonia Detection Challenge
https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data (benötigt Registrierung bei Kaggle)

Der folgende Code wurde von https://github.com/lindawangg/COVID-Net/blob/master/create_COVIDx_binary.ipynb übernommen und angepasst.

Falls Sie die Daten wie oben angegeben in den Unterordner "raw_data" in diesem Verzeichnis gelegt haben, können Sie die nächste Zelle direkt ausführen. Die aufbereiteten Bilder werden im Unterordner "covid_dataset" gespeichert. Die zugehörige Liste mit Patienten-ID, Dateiname und Label landet in der Datei "covid_dataset.txt". Bis der Datensatz zusammengestellt ist, kann es knapp 10 Minuten dauern. Brechen Sie die Ausführung vorher nicht ab und schließen Sie dieses Notebook nicht. Nach Beendigung werden die benötigte Zeit sowie die Verteilung des Datensatzes ausgegeben.

In [None]:
import numpy as np
import pandas as pd
import os
from shutil import copyfile
import pydicom as dicom
import cv2
import timeit

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Target directory for the consolidated images.
savepath = './covid_dataset'

MAXVAL = 255  # Range [0 255]; not used by the code below

# Root folder that holds the downloaded raw datasets.
_raw_root = './raw_data'

# https://github.com/ieee8023/covid-chestxray-dataset
cohen_imgpath = f'{_raw_root}/covid-chestxray-dataset/images'
cohen_csvpath = f'{_raw_root}/covid-chestxray-dataset/metadata.csv'

# https://github.com/agchung/Figure1-COVID-chestxray-dataset
fig1_imgpath = f'{_raw_root}/Figure1-COVID-chestxray-dataset/images'
fig1_csvpath = f'{_raw_root}/Figure1-COVID-chestxray-dataset/metadata.csv'

# https://github.com/agchung/Actualmed-COVID-chestxray-dataset
actmed_imgpath = f'{_raw_root}/Actualmed-COVID-chestxray-dataset/images'
actmed_csvpath = f'{_raw_root}/Actualmed-COVID-chestxray-dataset/metadata.csv'

# https://www.kaggle.com/tawsifurrahman/covid19-radiography-database
sirm_imgpath = f'{_raw_root}/COVID-19_Radiography_Dataset/COVID'
sirm_csvpath = f'{_raw_root}/COVID-19_Radiography_Dataset/COVID.metadata.xlsx'

# https://www.kaggle.com/c/rsna-pneumonia-detection-challenge
rsna_datapath = f'{_raw_root}/rsna-pneumonia-detection-challenge'
# Table from which the 'Normal' cases are taken.
rsna_csvname = 'stage_2_detailed_class_info.csv'
# Table from which the pneumonia cases are taken (Target == 1). Per the
# original author, images that are neither pneumonia nor normal carry Target 0.
rsna_csvname2 = 'stage_2_train_labels.csv'
rsna_imgpath = 'stage_2_train_images'

# Accumulators that are filled while the dataset is assembled.
train = []
train_count = dict.fromkeys(('normal', 'pneumonia', 'COVID-19'), 0)

# Finding/label found in the source metadata -> one of the three target classes.
mapping = {
    'COVID-19': 'COVID-19',
    'SARS': 'pneumonia',
    'MERS': 'pneumonia',
    'Streptococcus': 'pneumonia',
    'Klebsiella': 'pneumonia',
    'Chlamydophila': 'pneumonia',
    'Legionella': 'pneumonia',
    'E.Coli': 'pneumonia',
    'Normal': 'normal',
    'Lung Opacity': 'pneumonia',
    '1': 'pneumonia',
}

# patient id -> image file names already written (prevents duplicates).
patient_imgpath = {}

# Start of the runtime measurement that is printed at the end of the cell.
start_time = timeit.default_timer()

# Load the metadata tables of the four COVID sources.
# View filter adapted from
# https://github.com/mlmed/torchxrayvision/blob/master/torchxrayvision/datasets.py#L814
cohen_csv = pd.read_csv(cohen_csvpath, nrows=None)
# Keep only the PA and AP views listed here.
views = ["PA", "AP", "AP Supine", "AP semi erect", "AP erect"]
cohen_idx_keep = cohen_csv.view.isin(views)
cohen_csv = cohen_csv[cohen_idx_keep]

# The Figure1 metadata is read as ISO-8859-1.
fig1_csv = pd.read_csv(fig1_csvpath, encoding='ISO-8859-1', nrows=None)
actmed_csv = pd.read_csv(actmed_csvpath, nrows=None)

sirm_csv = pd.read_excel(sirm_csvpath)

# Entries [patient id, image file name, label, source tag], grouped by label.
filename_label = {'normal': [], 'pneumonia': [], 'COVID-19': []}
# Number of collected entries per label.
count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}
# Patient ids of the COVID-19 entries, grouped by source.
covid_ds = {'cohen': [], 'fig1': [], 'actmed': [], 'sirm': []}

for index, row in cohen_csv.iterrows():
    # Use the last level of the '/'-separated finding hierarchy. str() turns a
    # missing finding (NaN) into 'nan', which is not in `mapping` and is skipped.
    f = str(row['finding']).split('/')[-1]
    if f in mapping:
        count[mapping[f]] += 1
        entry = [str(row['patientid']), row['filename'], mapping[f], 'cohen']
        filename_label[mapping[f]].append(entry)
        if mapping[f] == 'COVID-19':
            covid_ds['cohen'].append(str(row['patientid']))

for index, row in fig1_csv.iterrows():
    if str(row['finding']) == 'nan':
        continue
    f = row['finding'].split(',')[0]  # take the first finding
    if f not in mapping:
        continue
    # The image is stored as <patientid>.jpg or <patientid>.png. Rows without an
    # image are skipped; previously they reused the previous row's `entry`
    # (wrong file/label) or raised NameError on the first row.
    imagename = None
    for ext in ('.jpg', '.png'):
        if os.path.exists(os.path.join(fig1_imgpath, row['patientid'] + ext)):
            imagename = row['patientid'] + ext
            break
    if imagename is None:
        continue
    count[mapping[f]] += 1
    entry = [row['patientid'], imagename, mapping[f], 'fig1']
    filename_label[mapping[f]].append(entry)
    if mapping[f] == 'COVID-19':
        covid_ds['fig1'].append(row['patientid'])

for index, row in actmed_csv.iterrows():
    if not str(row['finding']) == 'nan':
        f = row['finding'].split(',')[0]  # take the first finding
        if f in mapping:
            count[mapping[f]] += 1
            entry = [row['patientid'], row['imagename'], mapping[f], 'actmed']
            filename_label[mapping[f]].append(entry)
            if mapping[f] == 'COVID-19':
                covid_ds['actmed'].append(row['patientid'])

# SIRM images (Kaggle COVID-19 Radiography Database): skip images whose URL also
# appears in the cohen metadata, and the image numbers listed in `discard`
# (list inherited from the upstream script; its rationale is not documented here).
sirm = set(sirm_csv['URL'])
cohen = set(cohen_csv['url'])
discard = ['100', '101', '102', '103', '104', '105',
           '110', '111', '112', '113', '122', '123',
           '124', '125', '126', '217']

for idx, row in sirm_csv.iterrows():
    patientid = row['FILE NAME']
    # The slice extracts the text between '(' and ')' of the file name.
    if row['URL'] not in cohen and patientid[patientid.find('(')+1:patientid.find(')')] not in discard:
        count[mapping['COVID-19']] += 1
        imagename = patientid + '.' + row['FORMAT'].lower()
        if not os.path.exists(os.path.join(sirm_imgpath, imagename)):
            # Fall back to the "name (n).ext" spelling with a space before '('.
            imagename = patientid.split('(')[0] + ' (' + patientid.split('(')[1] + '.' + row['FORMAT'].lower()
        entry = [patientid, imagename, mapping['COVID-19'], 'sirm']
        filename_label[mapping['COVID-19']].append(entry)
        covid_ds['sirm'].append(patientid)

print('Verteilung der COVID-Datensätze:')
print(count)

# Source image directory for each dataset tag stored in entry[3].
ds_imgpath = {'cohen': cohen_imgpath, 'fig1': fig1_imgpath, 'actmed': actmed_imgpath, 'sirm': sirm_imgpath}

# copyfile raises and cv2.imwrite silently fails (returns False) when the target
# directory is missing, so create it up front.
os.makedirs(savepath, exist_ok=True)

for key, entries in filename_label.items():
    arr = np.array(entries)
    if arr.size == 0:
        continue

    # Patient ids presumably reserved for the test split by the upstream COVIDx
    # script. NOTE: `test_patients` is not used in this cell -- every image below
    # is appended to `train`.
    if key == 'pneumonia':
        test_patients = ['8', '31']
    elif key == 'COVID-19':
        test_patients = ['19', '20', '36', '42', '86',
                         '94', '97', '117', '132',
                         # the comma after '191' was missing, which silently
                         # merged '191' and 'COVID-00024' into one string
                         '138', '144', '150', '163', '169', '174', '175', '179', '190', '191',
                         'COVID-00024', 'COVID-00025', 'COVID-00026', 'COVID-00027', 'COVID-00029',
                         'COVID-00030', 'COVID-00032', 'COVID-00033', 'COVID-00035', 'COVID-00036',
                         'COVID-00037', 'COVID-00038',
                         'ANON24', 'ANON45', 'ANON126', 'ANON106', 'ANON67',
                         'ANON153', 'ANON135', 'ANON44', 'ANON29', 'ANON201',
                         'ANON191', 'ANON234', 'ANON110', 'ANON112', 'ANON73',
                         'ANON220', 'ANON189', 'ANON30', 'ANON53', 'ANON46',
                         'ANON218', 'ANON240', 'ANON100', 'ANON237', 'ANON158',
                         'ANON174', 'ANON19', 'ANON195',
                         'COVID-19(119)', 'COVID-19(87)', 'COVID-19(70)', 'COVID-19(94)',
                         'COVID-19(215)', 'COVID-19(77)', 'COVID-19(213)', 'COVID-19(81)',
                         'COVID-19(216)', 'COVID-19(72)', 'COVID-19(106)', 'COVID-19(131)',
                         'COVID-19(107)', 'COVID-19(116)', 'COVID-19(95)', 'COVID-19(214)',
                         'COVID-19(129)']
    else:
        test_patients = []

    # Copy every (patient, image) pair once into `savepath`.
    for patient in arr:
        written = patient_imgpath.setdefault(patient[0], [])
        if patient[1] in written:
            continue  # skip since image has already been written
        written.append(patient[1])

        src = os.path.join(ds_imgpath[patient[3]], patient[1])
        if patient[3] == 'sirm':
            # SIRM images are re-encoded as single-channel grayscale, and spaces
            # are removed from their file names.
            image = cv2.imread(src)
            if image is None:  # cv2.imread reports unreadable/missing files with None
                raise FileNotFoundError(f'Bild konnte nicht gelesen werden: {src}')
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            patient[1] = patient[1].replace(' ', '')
            dst = os.path.join(savepath, patient[1])
            if not cv2.imwrite(dst, gray):
                raise OSError(f'Bild konnte nicht geschrieben werden: {dst}')
        else:
            copyfile(src, os.path.join(savepath, patient[1]))
        train.append(patient)
        train_count[patient[2]] += 1

# Add normal and the remaining pneumonia cases from
# https://www.kaggle.com/c/rsna-pneumonia-detection-challenge
csv_normal = pd.read_csv(os.path.join(rsna_datapath, rsna_csvname), nrows=None)
csv_pneu = pd.read_csv(os.path.join(rsna_datapath, rsna_csvname2), nrows=None)

# Patient ids per label, in file order: rows of class 'Normal' from the detailed
# class info, rows with Target == 1 from the train labels. Ids may repeat across
# rows; repeats are skipped when writing below.
patients = {
    'normal': csv_normal.loc[csv_normal['class'] == 'Normal', 'patientId'].tolist(),
    'pneumonia': csv_pneu.loc[csv_pneu['Target'].astype(int) == 1, 'patientId'].tolist(),
}

for key, patient_ids in patients.items():
    for patient in patient_ids:
        if patient in patient_imgpath:
            continue  # skip since image has already been written
        patient_imgpath[patient] = [patient]

        # Convert the DICOM image to PNG.
        ds = dicom.dcmread(os.path.join(rsna_datapath, rsna_imgpath, patient + '.dcm'))
        imgname = patient + '.png'
        dst = os.path.join(savepath, imgname)
        # cv2.imwrite signals failure (e.g. missing output directory) only via
        # its return value; never list an image that was not written.
        if not cv2.imwrite(dst, ds.pixel_array):
            raise OSError(f'Bild konnte nicht geschrieben werden: {dst}')
        train.append([patient, imgname, key, 'rsna'])
        train_count[key] += 1

# Final statistics
print('Verteilung des Gesamtdatensatzes: ', train_count)
print('Gesamtanzahl der Trainingsbeispiele: ', len(train))

# Export the label list: one line per image, "<patientid> <filename> <label>",
# separated by single spaces (the source tag in sample[3] is not written).
# The context manager closes the file even if a write fails.
with open("covid_dataset.txt", 'w', encoding='utf-8') as train_file:
    for sample in train:
        train_file.write(f'{sample[0]} {sample[1]} {sample[2]}\n')

# Total runtime of the cell, measured from `start_time`.
elapsed = timeit.default_timer() - start_time

print('Took ' + str(elapsed) + ' seconds.')