In [None]:
import os
import sys
import numpy as np
import argparse
import copy
import random
import json
from glob import glob
import matplotlib.pyplot as plt

#Rand Number using Numpy
from numpy.random import default_rng

#Pytorch
import torch
from torch.autograd import grad
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import torch.utils.data as data_utils

#Pillow
from PIL import Image

#XRayVision
import torchxrayvision as xrv

In [None]:
# Image pipeline shared by the translation and export cells below.
#
# NOTE(review): torchxrayvision datasets commonly return images already rescaled
# (roughly to [-1024, 1024]); transforms.ToTensor() does not rescale float ('F'
# mode) PIL images, so the ImageNet Normalize below may not be receiving [0, 1]
# inputs — confirm this is intended.

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _gray_to_rgb(x):
    """Repeat the channel dimension 3 times (expects a 1-channel C×H×W tensor)."""
    return x.repeat(3, 1, 1)


# xrv sample image -> PIL image.
to_pil = xrv.datasets.ToPILImage()

# Resize to a fixed 224x224 (aspect ratio is not preserved).
to_resize = transforms.Resize([224, 224])

# Random geometric augmentation: rotation/shift/scale, random crop back to 224, random flip.
to_augment = transforms.Compose([
    transforms.RandomAffine(degrees=45, translate=(0.15, 0.15), scale=(0.85, 1.15)),
    transforms.RandomResizedCrop(size=224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
])

# PIL image -> 3-channel tensor normalized with ImageNet statistics.
to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_gray_to_rgb),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

In [None]:
def translate(img, label, domain, idx, test_case, ver_shifts=None):
    """Optionally apply a per-domain vertical shift to an X-ray image, then resize it.

    The image is converted with ``to_pil``. Depending on label, domain and
    ``test_case``, it is either resized as-is or shifted vertically by
    ``ver_shifts[domain]`` rows. After the shift, the exposed blank band is
    cropped away and the result is resized with ``to_resize``.

    Which samples are shifted:
      * label 1: only when ``test_case`` is truthy and ``domain`` is not
        'nih'/'chex' (i.e. 'kaggle' for the configured domains).
      * label 0: always, except when ``test_case`` is truthy and ``domain`` is 'kaggle'.
      * any other label value (e.g. NaN) is shifted.

    Args:
        img: raw sample image, in whatever form ``to_pil`` accepts.
        label: label value compared with ``== 0`` / ``== 1`` (a Python/NumPy
            scalar or a one-element tensor).
        domain: dataset key, e.g. 'nih', 'chex' or 'kaggle'.
        idx: sample index; currently unused (kept for interface compatibility and
            for the optional ``save_img`` debug hook below).
        test_case: truthy to use the test-time rule above for which labels are shifted.
        ver_shifts: optional mapping domain -> rows to shift down (negative shifts
            up). Defaults to {'nih': 45, 'chex': 35, 'kaggle': 0}.

    Returns:
        The PIL image produced by ``to_resize``.

    Raises:
        ValueError: if a shift is required but ``domain`` has no entry in
            ``ver_shifts``, or the shift is at least the image height.
    """
    if ver_shifts is None:
        ver_shifts = {'nih': 45, 'chex': 35, 'kaggle': 0}

    img = to_pil(img)

    # Label 1: untouched during training, and at test time for NIH / CheXpert.
    if label == 1 and (domain in ['nih', 'chex'] or not test_case):
        return to_resize(img)

    # Label 0: untouched only at test time for Kaggle.
    if label == 0 and domain == 'kaggle' and test_case:
        return to_resize(img)

    if domain not in ver_shifts:
        raise ValueError("translate: no vertical shift configured for domain %r" % (domain,))
    ver_shift = ver_shifts[domain]

    width, height = img.size
    if abs(ver_shift) >= height:
        raise ValueError("translate: shift %d is not smaller than image height %d" % (ver_shift, height))

    # PIL AFFINE maps each output pixel (x, y) to input (a*x + b*y + c, d*x + e*y + f),
    # so f = -ver_shift moves the content down by ver_shift rows (up if negative).
    shifted = img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, -ver_shift))

    # Crop away the blank band the shift exposed; the remaining height shrinks by |shift|.
    out_size = (width, height - abs(ver_shift))
    if ver_shift < 0:
        # Content moved up: blank band is at the bottom, keep the top rows.
        box = (0, 0, width, height - abs(ver_shift))
    else:
        # Content moved down: blank band is at the top, drop it.
        box = (0, ver_shift, width, height)
    shifted = shifted.transform(out_size, Image.EXTENT, box)

    result = to_resize(shifted)
    # Debug: save_img(to_resize(img), result, idx) writes a side-by-side comparison.
    return result

In [None]:
def save_img(img, img_new, idx, out_dir='images'):
    """Save a side-by-side PNG of an image and its translated version.

    Each image is first rendered to its own temporary PNG with the 'Greys_r'
    colormap, so each is contrast-scaled independently. The two PNGs are read
    back, concatenated horizontally, and written to
    ``<out_dir>/figc_<idx>.png``. The temporary files are removed even if a
    step fails.

    Args:
        img: left-hand image (anything ``plt.imsave`` accepts).
        img_new: right-hand image; must have the same height as ``img``.
        idx: identifier embedded in the output file names.
        out_dir: output directory, created if missing. Defaults to 'images'.
    """
    os.makedirs(out_dir, exist_ok=True)
    path_a = os.path.join(out_dir, 'figa_' + str(idx) + '.png')
    path_b = os.path.join(out_dir, 'figb_' + str(idx) + '.png')
    path_c = os.path.join(out_dir, 'figc_' + str(idx) + '.png')

    try:
        # Both panels use the same colormap; each is normalized on its own data.
        plt.imsave(path_a, img, cmap='Greys_r')
        plt.imsave(path_b, img_new, cmap='Greys_r')

        # Reading the PNGs back yields already-colormapped RGBA arrays.
        side_by_side = np.concatenate((plt.imread(path_a), plt.imread(path_b)), axis=1)

        # cmap is ignored for RGB(A) input, so the rendered colors are preserved.
        plt.imsave(path_c, side_by_side, cmap='Greys_r')
    finally:
        for path in (path_a, path_b):
            if os.path.exists(path):
                os.remove(path)

In [None]:
# Load the three chest X-ray sources, keeping both PA and AP views.
d_nih = xrv.datasets.NIH_Dataset(
    imgpath="/data/datasets/NIH/images_224/",
    views=["PA", "AP"],
    unique_patients=False,
)

d_chex = xrv.datasets.CheX_Dataset(
    imgpath="/data/datasets/CheXpert-v1.0-small",
    csvpath="/data/datasets/CheXpert-v1.0-small/train.csv",
    views=["PA", "AP"],
    unique_patients=False,
)

d_rsna = xrv.datasets.RSNA_Pneumonia_Dataset(
    imgpath="/data/datasets/Kaggle/stage_2_train_images_jpg",
    views=["PA", "AP"],
    unique_patients=False,
)

# Relabel every dataset against xrv's shared default pathology list, so later
# cells can use one `default_pathologies.index('Pneumonia')` for all of them.
for _dataset in (d_nih, d_chex, d_rsna):
    xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, _dataset)

In [None]:
# Per-domain dataset object and split sizes (total samples per split; the
# sampling cell later balances each split half label 0 / half label 1).
# The numbers satisfy: test ≈ 30% of (train + val + test), val ≈ 25% of train.
_split_config = {
    #          dataset   test   train  val
    'nih':    (d_nih,    430,   800,   200),
    'chex':   (d_chex,   1402,  2618,  654),
    'kaggle': (d_rsna,   1803,  3368,  841),
}

data = {}
for data_name, (dataset, test_size, train_size, val_size) in _split_config.items():
    data[data_name] = {
        'obj': dataset,
        'test_size': test_size,
        'train_size': train_size,
        'val_size': val_size,
        'size': len(dataset),
    }

## Translating images

Create the directory `/data/datasets/chestxray/` (if it does not already exist), then run the following cells.


In [None]:
# Draw disjoint, label-balanced train/val/test index lists for each domain and
# save them as `<base_dir><domain>_<split>_indices.npy`.
base_dir = '/data/datasets/chestxray/'
os.makedirs(base_dir, exist_ok=True)

INDEX_SEED = 0  # fixed seed so the saved splits are reproducible
rng = default_rng(INDEX_SEED)
task = xrv.datasets.default_pathologies.index('Pneumonia')

for data_name in ['nih', 'chex', 'kaggle']:
    dataset = data[data_name]['obj']
    indices = rng.permutation(data[data_name]['size'])
    print(indices.shape)

    # Position in `indices`; shared across splits so train/val/test never overlap.
    count = 0
    for case in ['train', 'val', 'test']:
        size = data[data_name][case + '_size']
        count_lim = size // 2  # per-class quota
        ids = []
        count_l0 = 0
        count_l1 = 0

        while count_l0 < count_lim or count_l1 < count_lim:
            if count >= len(indices):
                raise RuntimeError(
                    f"{data_name}: ran out of samples filling the '{case}' split "
                    f"({count_l0}/{count_lim} label-0, {count_l1}/{count_lim} label-1)"
                )
            index = indices[count].item()
            count += 1

            # NOTE(review): this decodes the whole image just to read the label;
            # xrv datasets appear to expose `dataset.labels[index, task]` — confirm and switch.
            label = dataset[index]['lab'][task]

            if label == 0 and count_l0 < count_lim:
                count_l0 += 1
            elif label == 1 and count_l1 < count_lim:
                count_l1 += 1
            else:
                # NaN, any value other than 0/1, or a class whose quota is already full.
                continue
            ids.append(index)

        ids = np.array(ids)
        print(count_l0, count_l1)
        np.save(base_dir + data_name + '_' + case + '_' + 'indices.npy', ids)

In [None]:
# For every domain/split: load the saved indices, apply `translate`, and save
# three tensors: augmented images, un-augmented ("org") images, and labels.
base_dir = '/data/datasets/chestxray/'

AUGMENT_SEED = 0  # fixed seed so the random augmentations are reproducible
torch.manual_seed(AUGMENT_SEED)
task = xrv.datasets.default_pathologies.index('Pneumonia')

for data_name in ['nih', 'chex', 'kaggle']:
    dataset = data[data_name]['obj']

    for case in ['train', 'val', 'test']:
        indices = np.load(base_dir + data_name + '_' + case + '_' + 'indices.npy')
        print(indices.shape)

        imgs = []
        imgs_org = []
        label_values = []
        count_l0 = 0
        count_l1 = 0

        for index in indices.tolist():
            # Read the sample once: each dataset access decodes the image from disk.
            sample = dataset[index]
            raw_img = sample['img']
            label = sample['lab'][task]

            if np.isnan(label):
                # The index files only contain labeled samples; a NaN here means they are stale.
                raise ValueError(f"{data_name}/{case}: NaN label at index {index}")

            if label == 0:
                count_l0 += 1
            elif label == 1:
                count_l1 += 1

            # Both copies get the same translation from the RAW image; only `img` is augmented.
            img = to_tensor(to_augment(translate(raw_img, label, data_name, index, 1)))
            img_org = to_tensor(translate(raw_img, label, data_name, index, 1))

            if torch.equal(img, img_org):
                print(f'Error: augmentation left {data_name}/{case} index {index} unchanged')

            imgs.append(img)
            imgs_org.append(img_org)
            label_values.append(int(label))

        imgs = torch.stack(imgs)
        imgs_org = torch.stack(imgs_org)
        labels = torch.tensor(label_values, dtype=torch.long)
        print(imgs.shape, imgs_org.shape, labels.shape, count_l0, count_l1)

        torch.save(imgs, base_dir + data_name + '_trans_' + case + '_' + 'image.pt')
        torch.save(imgs_org, base_dir + data_name + '_trans_' + case + '_' + 'image_org.pt')
        torch.save(labels, base_dir + data_name + '_trans_' + case + '_' + 'label.pt')