In [1]:
import torchvision.models as models
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.linear import Linear
import torch
import skimage
#from dataset import MyDataLoader, Iterator
import tarfile, glob
import skimage
from skimage.io import imread, imsave
import pandas as pd
import numpy as np
from os.path import join
import pandas as pd
import pickle
from PIL import Image
from sklearn.preprocessing import MultiLabelBinarizer
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
import numpy as np
from os.path import join
import pandas as pd
import pickle
from PIL import Image

from sklearn.preprocessing import MultiLabelBinarizer

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm


In [2]:
# Notebook-level configuration.
# Absolute path of an image directory on the cluster filesystem
# (presumably the NIH ChestX-ray images, going by the directory name -- TODO confirm).
datadir = "/lustre03/project/6008064/jpcohen/ChestXray-NIHCC"
# Two-element size setting (224 x 224); presumably the network input
# resolution -- it is not read by any of the cells visible here.
inputsize = [224, 224]
#d = pd.read_csv("/lustre03/project/6008064/jpcohen/ChestXray-NIHCC/Data_Entry_2017.csv")

In [3]:
class NIHXrayDataset():
    """
    Binary pneumonia-vs-healthy view of the NIH ChestX-ray metadata CSV.

    Rows whose "Finding Labels" mention "Pneumonia" get label 1, rows that
    mention "No Finding" get label 0, and every other row is dropped.
    Samples are addressed by position (0 .. len-1) in the filtered table.

    Args:
        datadir: Directory containing the image files named in the CSV's
            "Image Index" column. Only needed when samples are loaded.
        csvpath: Path to the metadata CSV; it must provide the
            "Finding Labels" and "Image Index" columns.
        transform: Optional callable applied to each (H, W, 1) image array.
        nrows: Optional cap on the number of CSV rows to read.

    Attributes:
        csv: The filtered metadata table, including the added 'labels' column.
            The original row index is kept, so it may contain gaps.
        labels: The 'labels' column of `csv` (exposed for the dataloader wrapper).
    """

    def __init__(self, datadir, csvpath, transform=None, nrows=None):

        self.datadir = datadir
        self.transform = transform
        self.pathologies = ["Atelectasis", "Consolidation", "Infiltration",
                            "Pneumothorax", "Edema", "Emphysema", "Fibrosis", "Effusion", "Pneumonia",
                            "Pleural_Thickening", "Cardiomegaly", "Nodule", "Mass", "Hernia"]

        # Load data
        self.csv = pd.read_csv(csvpath, nrows=nrows)

        # Remove multi-finding images.
        #self.csv = self.csv[~self.csv["Finding Labels"].str.contains("\|")]

        # Get our two classes. na=False keeps the masks strictly boolean even
        # if a row has no finding text at all.
        findings = self.csv["Finding Labels"]
        idx_sick = findings.str.contains("Pneumonia", na=False)
        idx_heal = findings.str.contains("No Finding", na=False)

        # Exposed for our dataloader wrapper.
        self.csv['labels'] = 0
        self.csv.loc[idx_sick, 'labels'] = 1
        self.csv = self.csv[idx_sick | idx_heal]
        self.labels = self.csv['labels']

    def __len__(self):
        """Number of samples that survived the two-class filter."""
        return len(self.csv)

    def __getitem__(self, idx):
        """
        Load one sample by position.

        Returns:
            (image, label): the image as an (H, W, 1) array (or whatever
            `transform` returns for it) and its 0/1 label.

        Raises:
            ValueError: if the decoded image has fewer than two dimensions.
        """
        # The filtered frame keeps the original CSV row numbers as its index,
        # so samples must be looked up by position, not by index label.
        row = self.csv.iloc[idx]
        image_name = row['Image Index']

        # `imread` (skimage.io) and `join` (os.path) are the names imported at
        # the top of the notebook.
        im = imread(join(self.datadir, image_name))
        # For the ChestXRay dataset, range is [0, 255]

        # Check that images are 2D arrays
        if im.ndim > 2:
            im = im[:, :, 0]
        if im.ndim < 2:
            raise ValueError(
                "error, dimension lower than 2 for image {}".format(image_name))

        # Add color channel
        im = im[:, :, None]

        # Transform
        if self.transform:
            im = self.transform(im)

        return im, self.labels.iloc[idx]
    
    
class PCXRayDataset(Dataset):
    """
    Binary pneumonia-vs-normal dataset built from the PadChest label CSV.

    Each CSV row whose `Labels` text contains "pneumonia" gets label 1, each
    other row whose text contains "normal" gets label 0, and all remaining
    rows (including rows without labels) are dropped. One sample corresponds
    to one row of the filtered table and is addressed by position.

    NOTE(review): image paths are built from per-row `ImageID` / `ImageDir`
    columns -- the same field names the previous per-patient metadata lookup
    used. Confirm these columns exist in the CSV in use.

    Args:
        datadir: Root directory of the images. Only needed when samples are loaded.
        csvpath: Path to the label CSV (needs `Labels` and `PatientID`, plus
            `ImageID` / `ImageDir` for image loading).
        splitpath: Accepted for interface compatibility; not read.
        transform: Optional callable applied to the {'PA': ..., 'L': ...} sample dict.
        dataset: One of 'train', 'val', 'test'. Only validated, not otherwise used.
        pretrained: If True, the single grey channel is repeated 3 times.
        min_patients_per_label: Stored as `threshold`; not otherwise used here.
        exclude_labels: Stored; not otherwise used here.
        flat_dir: If True, images sit directly under `datadir`; otherwise
            under `datadir/<int(ImageDir)>/`.
    """

    def __init__(self, datadir, csvpath, splitpath, transform=None,
                 dataset='train', pretrained=False, min_patients_per_label=50,
                 exclude_labels=["other", "normal", "no finding"], flat_dir=True):
        super(PCXRayDataset, self).__init__()

        assert dataset in ['train', 'val', 'test']

        self.datadir = datadir
        self.transform = transform
        self.pretrained = pretrained
        self.threshold = min_patients_per_label
        self.exclude_labels = exclude_labels
        self.flat_dir = flat_dir
        self.csv = pd.read_csv(csvpath)

        # Our two classes. Rows with a missing `Labels` value match neither.
        label_text = self.csv['Labels']
        idx_sick = label_text.str.contains('pneumonia', na=False)
        idx_heal = label_text.str.contains('normal', na=False)

        # Exposed for our dataloader wrapper.
        self.csv['labels'] = 0
        self.csv.loc[idx_sick, 'labels'] = 1
        self.csv = self.csv[idx_sick | idx_heal]
        self.labels = self.csv['labels']

        self.idx2pt = dict(enumerate(self.csv.PatientID.unique()))

    def _image_path(self, row):
        """Build the on-disk path of the image described by one CSV row."""
        subdir = '' if self.flat_dir else str(int(row['ImageDir']))
        return join(self.datadir, subdir, row['ImageID'])

    @property
    def targets(self):
        """
        Binary labels of all samples as a 1-D array, in sample order.

        NOTE(review): the previous body read `self.metadata` and `self.mb`,
        which this class never creates; confirm a binary vector is what
        callers of `targets` expect.
        """
        return self.labels.values

    @property
    def data(self):
        """Load every sample's image and stack them along a new first axis,
        with a trailing channel axis appended (all images must share a shape)."""
        files = [self._image_path(row) for _, row in self.csv.iterrows()]

        print("Reading files")
        imgs = np.stack([np.array(Image.open(path)) for path in tqdm(files)])
        imgs = np.expand_dims(imgs, -1)
        return imgs

    def __len__(self):
        """Number of rows that survived the two-class filter."""
        return len(self.csv)

    def __getitem__(self, idx):
        """
        Load one sample by position.

        Returns:
            (image, label): the (transformed) image stored under 'PA' and its
            0/1 label.
        """
        # The filtered frame keeps the original CSV row numbers as its index,
        # so look the sample up by position rather than by index label.
        row = self.csv.iloc[idx]
        img = np.array(Image.open(self._image_path(row)))[..., np.newaxis]

        if self.pretrained:
            # Add color channel
            img = np.repeat(img, 3, axis=-1)

        # The sample transforms in this file expect both a 'PA' and an 'L'
        # entry. A CSV row describes a single image, so an independent copy
        # fills 'L' (a copy, so in-place transforms cannot touch 'PA' twice).
        # NOTE(review): confirm whether a real lateral view should be paired here.
        sample = {'PA': img, 'L': img.copy()}

        if self.transform is not None:
            sample = self.transform(sample)

        return sample['PA'], self.labels.iloc[idx]

    
class Normalize(object):
    """
    Rescales the 'PA' and 'L' images of a sample to float32 values.

    Each pixel value v becomes 2 * (v / 65536) - 1, so 0 maps to -1 and
    65536 maps to +1. The sample dict is updated in place and returned.
    """
    def __call__(self, sample):
        # Fetch both views before touching the dict.
        raw_views = [sample[key] for key in ('PA', 'L')]

        rescaled = [
            (2 * (view / 65536) - 1.).astype(np.float32)
            for view in raw_views
        ]

        sample['PA'], sample['L'] = rescaled
        return sample


class ToTensor(object):
    """
    Runs torchvision's ToTensor transform on the 'PA' and 'L' entries of a
    sample. The sample dict is updated in place and returned.
    """
    def __call__(self, sample):
        convert = transforms.ToTensor()
        # Same converter instance for both views, PA first.
        for view in ('PA', 'L'):
            sample[view] = convert(sample[view])

        return sample


class ToPILImage(object):
    """
    Runs torchvision's ToPILImage transform on the 'PA' and 'L' entries of a
    sample. The sample dict is updated in place and returned.
    """
    def __call__(self, sample):
        convert = transforms.ToPILImage()
        # Same converter instance for both views, PA first.
        for view in ('PA', 'L'):
            sample[view] = convert(sample[view])

        return sample


class GaussianNoise(object):
    """
    Perturbs the 'PA' and 'L' tensors of a sample with zero-mean Gaussian
    noise scaled by 0.05. The noise is added with `+=`, and the sample dict
    is updated and returned.
    """
    def __call__(self, sample):
        # Fetch both views before modifying anything.
        views = (sample['PA'], sample['L'])

        noisy = []
        for tensor in views:
            # Noise is drawn for PA first, then for L.
            tensor += torch.randn_like(tensor) * 0.05
            noisy.append(tensor)

        sample['PA'], sample['L'] = noisy
        return sample


class RandomRotation(object):
    """
    Randomly rotates the 'PA' and 'L' images of a sample.

    A rotation bound is drawn as np.random.rand() * 5 (i.e. in [0, 5)), a
    torchvision RandomRotation is built with that bound, and the transform is
    applied to PA and then to L. The sample dict is updated in place and
    returned.

    NOTE(review): torchvision's RandomRotation presumably samples its own
    angle on every call, so PA and L may end up rotated by different
    amounts -- confirm this is intended.
    """
    def __call__(self, sample):
        # Fetch both views before drawing any random number.
        views = (sample['PA'], sample['L'])

        max_degrees = np.random.rand() * 5.
        rotate = transforms.RandomRotation(max_degrees)

        rotated = [rotate(view) for view in views]

        sample['PA'], sample['L'] = rotated
        return sample

In [4]:
# Test!
# Smoke test: build both datasets from CSV files given by relative path and
# print how many rows each one labelled 1 (pneumonia). datadir is passed as
# None because no image is loaded in this cell.
nih = NIHXrayDataset(None, "Data_Entry_2017.csv")
#pc = PCXRayDataset("/lustre03/project/6008064/jpcohen/PADCHEST_SJ", "/lustre03/project/6008064/jpcohen/PADCHEST_SJ/labels_csv/PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv", 
#        None, dataset='train', min_patients_per_label=10)
# The third positional argument (splitpath) is None.
pc = PCXRayDataset(None, "PADCHEST_chest_x_ray_images_labels_160K_01.02.19.csv", 
        None, dataset='train', min_patients_per_label=10)    
print("sick: nih={}, pc={}".format((nih.labels == 1).sum(), (pc.labels == 1).sum()))

  if self.run_code(code, result):


sick: nih=1353, pc=8174
