In [1]:
import os
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models

from tqdm import tqdm
from PIL import Image
from pathlib import Path
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import MultiLabelBinarizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import Dict, List, Tuple, Union
import os
import torchvision.transforms as T

import torch
import tqdm
import pandas as pd
from lightly.transforms.multi_view_transform import MultiViewTransform

from lightly.transforms.multi_view_transform import MultiViewTransform
IMAGENET_STAT = {"mean":torch.tensor([0.4884, 0.4550, 0.4171]),
                 "std":torch.tensor([0.2596, 0.2530, 0.2556])}



train_transforms = T.Compose([T.Resize(256),
                              T.RandomHorizontalFlip(),
                              T.RandomAffine(degrees=45, scale=(.85, 1.15), shear=0, translate=(0.15, 0.15)),
                              T.CenterCrop(224),
                              T.ToTensor(),
                              T.Normalize(mean=IMAGENET_STAT["mean"],
                                          std=IMAGENET_STAT["std"])                                                                   
                            ])


val_test_transforms = T.Compose([T.Resize(256),
                                 T.CenterCrop(224),
                                 T.ToTensor(),
                                 T.Normalize(mean=IMAGENET_STAT["mean"],
                                             std=IMAGENET_STAT["std"])                                                                
                            ])

In [3]:
data_path = '/scratch/fs999/shamoutlab/data/nih_chest_xrays'

In [4]:
df = pd.read_csv(os.path.join(data_path,'Data_Entry_2017.csv'))
df.head()

Unnamed: 0,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Gender,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],Unnamed: 11
0,00000001_000.png,Cardiomegaly,0,1,58,M,PA,2682,2749,0.143,0.143,
1,00000001_001.png,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.143,0.143,
2,00000001_002.png,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168,
3,00000002_000.png,No Finding,0,2,81,M,PA,2500,2048,0.171,0.171,
4,00000003_000.png,Hernia,0,3,81,F,PA,2582,2991,0.143,0.143,


In [5]:
df.describe()

Unnamed: 0,Follow-up #,Patient ID,Patient Age,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],Unnamed: 11
count,112120.0,112120.0,112120.0,112120.0,112120.0,112120.0,112120.0,0.0
mean,8.573751,14346.381743,46.901463,2646.078844,2486.438842,0.155649,0.155649,
std,15.40632,8403.876972,16.839923,341.246429,401.268227,0.016174,0.016174,
min,0.0,1.0,1.0,1143.0,966.0,0.115,0.115,
25%,0.0,7310.75,35.0,2500.0,2048.0,0.143,0.143,
50%,3.0,13993.0,49.0,2518.0,2544.0,0.143,0.143,
75%,10.0,20673.0,59.0,2992.0,2991.0,0.168,0.168,
max,183.0,30805.0,414.0,3827.0,4715.0,0.1988,0.1988,


In [6]:
df.count()

Image Index                    112120
Finding Labels                 112120
Follow-up #                    112120
Patient ID                     112120
Patient Age                    112120
Patient Gender                 112120
View Position                  112120
OriginalImage[Width            112120
Height]                        112120
OriginalImagePixelSpacing[x    112120
y]                             112120
Unnamed: 11                         0
dtype: int64

In [7]:
df[df['Image Index'].isna()].count()

Image Index                    0
Finding Labels                 0
Follow-up #                    0
Patient ID                     0
Patient Age                    0
Patient Gender                 0
View Position                  0
OriginalImage[Width            0
Height]                        0
OriginalImagePixelSpacing[x    0
y]                             0
Unnamed: 11                    0
dtype: int64

In [8]:
len(df['Image Index'].unique()), df['Image Index'].count()


(112120, 112120)

In [9]:
df['Image Index'] = df['Image Index'].apply(lambda x: os.path.join(data_path,'images',x))

In [10]:
df[df['Finding Labels'] == 'No Finding'].count()

Image Index                    60361
Finding Labels                 60361
Follow-up #                    60361
Patient ID                     60361
Patient Age                    60361
Patient Gender                 60361
View Position                  60361
OriginalImage[Width            60361
Height]                        60361
OriginalImagePixelSpacing[x    60361
y]                             60361
Unnamed: 11                        0
dtype: int64

In [11]:
df.drop(df[df['Finding Labels'] == 'No Finding'].index, inplace=True)

In [12]:
labels = set()

for label in df['Finding Labels'].values:
    for l in label.split('|'):
        labels.add(l)
        
labels = list(labels)

In [13]:
labels = list(map(lambda x: x.lower(), labels))

labels

['nodule',
 'edema',
 'fibrosis',
 'consolidation',
 'cardiomegaly',
 'pneumonia',
 'pleural_thickening',
 'hernia',
 'effusion',
 'infiltration',
 'atelectasis',
 'pneumothorax',
 'emphysema',
 'mass']

In [14]:
len(labels)

14

In [15]:
df.head()

Unnamed: 0,Image Index,Finding Labels,Follow-up #,Patient ID,Patient Age,Patient Gender,View Position,OriginalImage[Width,Height],OriginalImagePixelSpacing[x,y],Unnamed: 11
0,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,Cardiomegaly,0,1,58,M,PA,2682,2749,0.143,0.143,
1,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,Cardiomegaly|Emphysema,1,1,58,M,PA,2894,2729,0.143,0.143,
2,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,Cardiomegaly|Effusion,2,1,58,M,PA,2500,2048,0.168,0.168,
4,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,Hernia,0,3,81,F,PA,2582,2991,0.143,0.143,
5,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,Hernia,1,3,74,F,PA,2500,2048,0.168,0.168,


In [16]:
df.drop(df.columns.difference(['Image Index', 'Finding Labels']), axis=1, inplace=True)

In [17]:
label_mlb = []
for label in df['Finding Labels'].values:
    splits = label.split('|')
    label_mlb.append(set(splits))

mlb = MultiLabelBinarizer()
label_array = mlb.fit_transform(label_mlb)

In [18]:
mlb.classes_

array(['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
       'Effusion', 'Emphysema', 'Fibrosis', 'Hernia', 'Infiltration',
       'Mass', 'Nodule', 'Pleural_Thickening', 'Pneumonia',
       'Pneumothorax'], dtype=object)

In [19]:
final_df = pd.DataFrame(label_array, columns=mlb.classes_)

In [20]:
final_df.insert(loc=0, column='Image', value=df['Image Index'].values)

In [21]:
final_df

Unnamed: 0,Image,Atelectasis,Cardiomegaly,Consolidation,Edema,Effusion,Emphysema,Fibrosis,Hernia,Infiltration,Mass,Nodule,Pleural_Thickening,Pneumonia,Pneumothorax
0,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,0,1,0,0,0,0,0,0,0,0,0,0,0,0
1,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,0,1,0,0,0,1,0,0,0,0,0,0,0,0
2,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,0,1,0,0,1,0,0,0,0,0,0,0,0,0
3,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,0,0,0,0,0,0,0,1,0,0,0,0,0,0
4,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,0,0,0,0,0,0,0,1,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
51754,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,0,0,1,0,0,0,0,0,0,0,0,1,0,0
51755,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,0,0,0,0,0,0,0,0,1,0,0,0,0,0
51756,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,0,0,0,0,0,0,0,0,0,1,1,0,0,0
51757,/scratch/fs999/shamoutlab/data/nih_chest_xrays...,0,0,0,0,0,0,0,0,0,0,0,1,0,0


In [22]:
final_df.to_csv('nih-dataset.csv',index=False)

In [35]:
class NIHDataset(Dataset):
    def __init__(self, root, data_path, transform=None):
        self.root = root
        self.df = pd.read_csv('./nih-dataset.csv')
        self.transform = transform
        
        file = open(self.root)
        images = file.read().splitlines()
        
        ids = []
        
        for idx, path in enumerate(self.df['Image']):
            if path.split('/')[-1] in images:
                ids.append(idx)
        
        self.df = self.df.iloc[ids, :].reset_index(drop=True)
        self.images = self.df['Image'].values
        self.labels = self.df.iloc[:, 1:].values
        labels = list(map(lambda x: x.lower(), self.df.columns[1:]))
        self.classes = {v: k for k, v in enumerate(labels)}
        
    def __getitem__(self, item):
        img = Image.open(self.images[item]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
            
        return img, torch.tensor(self.labels[item], dtype=torch.float32)
    
    def __len__(self):
        return len(self.df)

In [36]:
# train_dataset = NIHDataset(root=os.path.join(data_path,'train_list.txt'),
#                           df = final_df,
#                           transform=train_transforms)

In [37]:
val_dataset = NIHDataset(root=os.path.join(data_path,'val_list.txt'),
                          data_path= final_df,
                          transform=train_transforms)

In [38]:
# test_dataset = NIHDataset(root=os.path.join(data_path,'test_list.txt'),
#                           df = final_df,
#                           transform=train_transforms)

In [41]:
nih_dataloader = DataLoader(dataset=val_dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=24,
                             pin_memory=True,
                             drop_last=False
                             )

In [42]:
next(iter(nih_dataloader))

[tensor([[[[-0.2650, -0.2499, -0.2499,  ..., -1.8814, -1.8814, -1.8814],
           [-0.2499, -0.2650, -0.2499,  ..., -1.8814, -1.8814, -1.8814],
           [-0.2499, -0.2499, -0.2499,  ..., -1.8814, -1.8814, -1.8814],
           ...,
           [-1.8814, -1.8814, -1.8814,  ..., -1.8814, -1.8814, -1.8814],
           [-1.8814, -1.8814, -1.8814,  ..., -1.8814, -1.8814, -1.8814],
           [-1.8814, -1.8814, -1.8814,  ..., -1.8814, -1.8814, -1.8814]],
 
          [[-0.1399, -0.1244, -0.1244,  ..., -1.7984, -1.7984, -1.7984],
           [-0.1244, -0.1399, -0.1244,  ..., -1.7984, -1.7984, -1.7984],
           [-0.1244, -0.1244, -0.1244,  ..., -1.7984, -1.7984, -1.7984],
           ...,
           [-1.7984, -1.7984, -1.7984,  ..., -1.7984, -1.7984, -1.7984],
           [-1.7984, -1.7984, -1.7984,  ..., -1.7984, -1.7984, -1.7984],
           [-1.7984, -1.7984, -1.7984,  ..., -1.7984, -1.7984, -1.7984]],
 
          [[ 0.0098,  0.0252,  0.0252,  ..., -1.6318, -1.6318, -1.6318],
           [ 