In [24]:
import os
from collections import Counter
from glob import glob
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import roc_auc_score, f1_score, precision_recall_curve
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

In [25]:
class ChestXrayDataset(Dataset):
    """Multi-label chest X-ray dataset backed by a metadata DataFrame.

    Each item is an ``(image, target)`` pair where ``image`` is the RGB image
    named by the row's ``"Image Index"`` column (after the optional transform)
    and ``target`` is a multi-hot float tensor with one slot per class.

    Args:
        dataframe: Table with at least the ``"Image Index"`` and
            ``"Finding Labels"`` columns; labels are ``|``-separated.
        image_directory: Directory that ``"Image Index"`` values are joined to.
        class_names: Target class names. They are lower-cased and spaces are
            replaced by underscores before being matched against row labels.
        transform: Optional callable applied to the loaded PIL image.
    """

    @staticmethod
    def _canonical(name):
        # Canonical spelling used on both sides of the label lookup.
        return name.lower().replace(" ", "_")

    def __init__(self, dataframe, image_directory, class_names, transform=None):
        # Positional row access in __getitem__ needs a clean 0..n-1 index.
        self.dataframe = dataframe.reset_index(drop=True)
        self.image_directory = image_directory
        self.transform = transform
        self.class_names = list(map(self._canonical, class_names))
        self.class_name_to_index = dict(
            zip(self.class_names, range(len(self.class_names)))
        )

    def __len__(self):
        """Number of rows (samples) in the backing DataFrame."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Load the image and build the multi-hot target for row ``index``."""
        record = self.dataframe.iloc[index]
        full_path = os.path.join(self.image_directory, record["Image Index"])
        image = Image.open(full_path).convert("RGB")

        target = torch.zeros(len(self.class_names))
        for raw_label in record["Finding Labels"].split("|"):
            slot = self.class_name_to_index.get(self._canonical(raw_label.strip()))
            # Labels outside class_names (e.g. "No Finding") leave the target untouched.
            if slot is not None:
                target[slot] = 1.0

        if self.transform:
            image = self.transform(image)

        return image, target

In [26]:
# Image root and metadata CSV, both relative to the notebook's directory.
img_dir = "../Image/"
csv_path = "../Data/Data_Entry_2017.csv"

# Target pathologies. "No Finding" is not in this list, and "Pleural Thickening"
# is spelled with a space here; ChestXrayDataset normalises case and spaces
# before matching it against the labels in the CSV.
class_names = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Effusion",
    "Emphysema",
    "Fibrosis",
    "Hernia",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pleural Thickening",
    "Pneumonia",
    "Pneumothorax",
]

# Full metadata table, one row per image.
df = pd.read_csv(csv_path)

In [34]:
# Count how often each finding occurs across all rows of the metadata table.
# A row with several '|'-separated findings contributes one count to each of them.
# (Counter is imported with the other imports in the first cell.)
counts = Counter(chain.from_iterable(df["Finding Labels"].str.split("|")))
print(counts)

Counter({'No Finding': 60361, 'Infiltration': 19894, 'Effusion': 13317, 'Atelectasis': 11559, 'Nodule': 6331, 'Mass': 5782, 'Pneumothorax': 5302, 'Consolidation': 4667, 'Pleural_Thickening': 3385, 'Cardiomegaly': 2776, 'Emphysema': 2516, 'Edema': 2303, 'Fibrosis': 1686, 'Pneumonia': 1431, 'Hernia': 227})


In [35]:
def compute_weight(row):
    """Return the sampling weight of one metadata row.

    The weight is the sum of inverse frequencies of the row's findings, so rows
    carrying rare findings get larger weights. Frequencies are read from the
    module-level ``counts`` mapping.

    Args:
        row: Mapping/Series with a ``"Finding Labels"`` string of
            ``|``-separated finding names.

    Returns:
        Sum of ``1 / counts[finding]`` over the row's findings.
    """
    weight = 0
    for finding in row["Finding Labels"].split("|"):
        weight += 1 / counts[finding]
    return weight

# Attach a per-row sampling weight (row-wise apply of compute_weight).
df["weight"] = df.apply(compute_weight, axis="columns")

In [40]:
# Draw a fixed-seed sample of 300 rows, favouring rows with rare findings
# via the precomputed "weight" column.
# The sample is bound to `weighted_sample` because later cells read that name
# (label counts, CSV export); previously it was only bound to `df`, which made
# those cells fail with NameError on a fresh kernel.
weighted_sample = df.sample(
    n=300,
    weights="weight",
    random_state=42
)

# Downstream cells keep working on `df`. Use a copy so that their in-place
# column edits (e.g. rewriting "Image Index") do not leak into weighted_sample.
df = weighted_sample.copy()

In [41]:
# Recount findings on the 300-row sample to check how balanced it is.
# `df` holds the sample at this point; the previous reference to
# `weighted_sample` was not defined anywhere in the notebook.
# NOTE: this rebinds `counts`, which compute_weight reads. Re-running the
# weight cell after this one would use sample counts, not full-dataset counts.
counts = Counter(
    disease
    for labels in df["Finding Labels"].str.split("|")
    for disease in labels
)
print(counts)

Counter({'Infiltration': 76, 'Effusion': 60, 'Atelectasis': 56, 'Mass': 55, 'Pleural_Thickening': 45, 'Consolidation': 35, 'Nodule': 34, 'Pneumothorax': 32, 'Edema': 31, 'Cardiomegaly': 29, 'Fibrosis': 23, 'Emphysema': 23, 'No Finding': 19, 'Hernia': 19, 'Pneumonia': 17})


In [42]:
# Index every PNG found under ../Image/images*/<subdir>/ by its bare file name.
data_image_paths = {os.path.basename(x): x for x in glob(os.path.join('..', 'Image', 'images*', '*', '*.png'))}
print('Scans found:', len(data_image_paths), ', Total Headers', df.shape[0])

# Replace each file name with its path on disk. The lookup goes through
# os.path.basename so the cell is idempotent: re-running it on already-resolved
# paths yields the same paths instead of mapping everything to None.
df['Image Index'] = df['Image Index'].map(
    lambda name: data_image_paths.get(os.path.basename(name)) if isinstance(name, str) else None
)

# Rows whose image is not on disk end up as None; surface that instead of
# letting it fail later inside the Dataset.
missing_images = int(df['Image Index'].isna().sum())
if missing_images:
    print(f"Warning: {missing_images} rows have no matching image file on disk")

df["Finding Labels"] = df["Finding Labels"].astype(str)

# Sorted array of the distinct findings present in the sample.
all_labels = np.unique(list(chain.from_iterable(df['Finding Labels'].str.split('|'))))
print(all_labels)
df.head()

Scans found: 112120 , Total Headers 300
['Atelectasis' 'Cardiomegaly' 'Consolidation' 'Edema' 'Effusion'
 'Emphysema' 'Fibrosis' 'Hernia' 'Infiltration' 'Mass' 'No Finding'
 'Nodule' 'Pleural_Thickening' 'Pneumonia' 'Pneumothorax']


Unnamed: 0,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Gender,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],Unnamed: 11,weight
43689,../Image/images_005/images/00011251_010.png,Infiltration|Pneumonia,10,11251,76,M,AP,2500,2048,0.168,0.168,,0.000749
106852,../Image/images_012/images/00028844_009.png,Effusion|Infiltration,9,28844,65,M,AP,3056,2544,0.139,0.139,,0.000125
81005,../Image/images_009/images/00019896_004.png,Fibrosis,4,19896,61,F,PA,3056,2544,0.139,0.139,,0.000593
67334,../Image/images_008/images/00016623_001.png,Effusion|Mass,1,16623,83,M,PA,2694,2445,0.143,0.143,,0.000248
18008,../Image/images_003/images/00004832_027.png,Consolidation|Edema|Pneumothorax,27,4832,33,M,AP,2500,2048,0.168,0.168,,0.000837


In [43]:
# Split by patient rather than by image, so every image of a given patient
# lands in exactly one of train / val / test.
unique_patient_ids = df['Patient ID'].unique()

# 20% of patients are held out for test; 25% of the remaining 80% (i.e. 20% of
# all patients) go to validation.
train_val_patient_ids, test_patient_ids = train_test_split(
    unique_patient_ids,
    test_size=0.2,
    random_state=42,
)
train_patient_ids, val_patient_ids = train_test_split(
    train_val_patient_ids,
    test_size=0.25,
    random_state=42,
)

# Select the rows belonging to each group of patients.
train_df, val_df, test_df = (
    df[df['Patient ID'].isin(patient_ids)]
    for patient_ids in (train_patient_ids, val_patient_ids, test_patient_ids)
)

print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

Train: 177, Val: 61, Test: 62


In [44]:
# Persist the sampled metadata for reuse outside this notebook.
sample_csv_path = "../Final_Version/data/Data_Entry_2017_Sample.csv"

# to_csv does not create missing parent directories.
os.makedirs(os.path.dirname(sample_csv_path), exist_ok=True)

# Write from `df`, which holds the sample at this point; the previous reference
# to `weighted_sample` was not defined anywhere in the notebook.
# "Image Index" was rewritten to on-disk paths in an earlier cell, so restore
# the bare file names to keep the same schema as Data_Entry_2017.csv.
# NOTE(review): assumes consumers of this CSV expect bare file names — confirm.
sample_to_save = df.assign(
    **{
        "Image Index": df["Image Index"].map(
            lambda path: os.path.basename(path) if isinstance(path, str) else path
        )
    }
)
sample_to_save.to_csv(sample_csv_path, index=False)