# Make the training dataset for DONGRONG NET

### Train data and val data are put into a directory like this:
data

 -- train
 ---- COVID
 ---- NORMAL
 ---- PNEUMONIA
 
 -- test
 ---- COVID
 ---- NORMAL
 ---- PNEUMONIA
 
 -- val
 ---- COVID
 ---- NORMAL
 ---- PNEUMONIA

In [1]:
# RICORD 1C has 100% COVID images
# RSNA dataset has either NORMAL or PNEUMONIA images

# RSNA
# RSNA is already split into test/train/val with train & val containing 100 patients each

import pydicom as dicom
import os, sys, datetime, random, math
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import glob

# RICORD
# Patient dates have been hash-dated during anonymisation (https://www.rsna.org/-/media/Files/RSNA/Covid-19/RICORD/RSNA-Covid-19-Deidentification-Protocol.pdf)
# Just pull 1 patient's data out at a time 

flag_save = True
out_dir = Path("D:/data/RICORD-1c/ricord_trainValTest/")

# RICORD metadata file paths
ricord_dir = 'D:/data/RICORD-1c/MIDRC-RICORD-1C'
ricord_meta_file = 'D:/data/RICORD-1c/MIDRC-RICORD-1c Clinical Data Jan 13 2021 .xls'
ricord_set_file = 'D:/data/RICORD-1c/ricord_data_set.txt'

In [2]:
# Source code is from Linda Wang's COVID-Net code
"""Helper functions"""
def _extract_data(df_row):
    return df_row['Anon MRN'], df_row['Anon TCIA Study Date'], df_row['Anon Exam Description'], df_row['Anon Study UID']

def load_ricord_metadata(ricord_meta_file):
    df = pd.read_excel(ricord_meta_file, sheet_name='CR Pos - TCIA Submission')
    ricord_metadata = []
    for index, row in df.iterrows():
        ricord_metadata.append(_extract_data(row))
    return ricord_metadata

def make_ricord_dict(ricord_data_set_file):
    """Loads bboxes from the given text file"""
    ricord_dict = {}
    with open(ricord_data_set_file, 'r') as f:
        for line in f.readlines():
            # Values after file name are crop dimensions
            if(len(line.split()) > 1):
                fname, xmin, ymin, xmax, ymax = line.rstrip('\n').split()
                bbox = tuple(int(c) for c in (xmin, ymin, xmax, ymax))
                ricord_dict[fname] = bbox
            else:
                fname = line.rstrip('\n')
                ricord_dict[fname] = None
                
    return ricord_dict

In [3]:
# CODE
ricord_dict = make_ricord_dict(ricord_set_file)
metadata = load_ricord_metadata(ricord_meta_file)

file_count = 0
table={"path":[],"MRN":[],"UID":[],"AcqDateTime":[],"i_value":[]}
for mrn, date, desc, uid in metadata:
    date = date.strftime('%m-%d-%Y')
    uid = uid[-5:]
    study_dir = os.path.join(ricord_dir, 'MIDRC-RICORD-1C-{}'.format(mrn), '*-{}'.format(uid)) # this controls the images being examined
    dcm_files = sorted(glob.glob(os.path.join(study_dir, '*', '*.dcm')))
    
    for i, dcm_file in enumerate(dcm_files):
        # Load DICOM image
        ds = dicom.dcmread(dcm_file)
        acqDT = float(ds.AcquisitionDate+ds.AcquisitionTime)
        # Fill in table
        table["path"].append(dcm_file)
        table["MRN"].append(mrn)
        table["UID"].append(uid)
        table["AcqDateTime"].append(acqDT)
        table["i_value"].append(i)
df = pd.DataFrame(table)

ImportError: Missing optional dependency 'xlrd'. Install xlrd >= 1.0.0 for Excel support Use pip or conda to install xlrd.

In [None]:
# For each MRN:
# find the earliest file that fits the out_fname inside ricord_dict
files = []
filenames = []
unique_mrns = pd.unique(df["MRN"])
for mrn in unique_mrns:
    # Find idxs with this mrn
    idx = df["MRN"]==mrn
    temp_df = df[idx]
    # Sort by ascending AcqDateTime
    temp_df = temp_df.sort_values(by=["AcqDateTime"])
    
    for index, row in temp_df.iterrows():
        uid = row["UID"]
        i = row["i_value"]
        out_fname = 'MIDRC-RICORD-1C-{}-{}-{}.png'.format(mrn, uid, i)
        if out_fname not in ricord_dict:
            continue
        else:
            # Take this row of info
            files.append(row["path"])
            filenames.append(out_fname)
            # break # in order to get only 1 patient

In [None]:
from pydicom.pixel_data_handlers import apply_modality_lut, apply_voi_lut
import cv2

for file_iter_count, file in enumerate(files):
    savefile_name = filenames[file_iter_count]
    if flag_save:
        # Load DICOM image
        ds = dicom.dcmread(file)

        # Verify orientation
        if ds.ViewPosition != 'AP' and ds.ViewPosition != 'PA':
            print('Image from MRN-{} Date-{} UID-{} in position {}'.format(mrn, date, uid, ds.ViewPosition))
            continue

        # Apply transformations if required
        if ds.pixel_array.dtype != np.uint8:
            # Apply LUT transforms
            arr = apply_modality_lut(ds.pixel_array, ds)
            if arr.dtype == np.float64 and ds.RescaleSlope == 1 and ds.RescaleIntercept == 0:
                arr = arr.astype(np.uint16)
            arr = apply_voi_lut(arr, ds)
            arr = arr.astype(np.float64)

            # Normalize to [0, 1]
            arr = (arr - arr.min())/arr.ptp()

            # Invert MONOCHROME1 images
            if ds.PhotometricInterpretation == 'MONOCHROME1':
                arr = 1. - arr

            # Convert to uint8
            image = np.uint8(255.*arr)
        else:
            # Invert MONOCHROME1 images
            if ds.PhotometricInterpretation == 'MONOCHROME1':
                image = 255 - ds.pixel_array
            else:
                image = ds.pixel_array

        # Crop if necessary
        bbox = ricord_dict[out_fname]
        if bbox is not None:
            image = image[bbox[1]:bbox[3], bbox[0]:bbox[2]]

        # Save image
        random_number = np.random.uniform()
        if random_number >= 0.8 and random_number < 0.9:
            split_switch = "validation"
        elif random_number >= 0.9:
            split_switch = "test"
        else:
            split_switch = "train"
            
        save_dir = os.path.join(out_dir, split_switch)
        out_path = Path(save_dir)
        out_path.mkdir(parents=True, exist_ok=True)
        
        save_path = os.path.join(save_dir, savefile_name)
        cv2.imwrite(save_path, image)
        file_count += 1
print('Created {} files'.format(file_count))


# RESIZE IMAGES & SAVE INTO ANOTHER DIRECTORY
Run this in order to resize the created files previously into smaller sizes for faster training.

In [1]:
import numpy as np
import os, sys
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from skimage.transform import resize
import torch
import Rajaraman_ResNet_BS.RajaramanModel as RajaramanModel
import GusarevBoneSuppression.GusarevModel as GusarevModel

import RenGeBoneSuppression.custom_transforms as custom_transforms
import RenGeBoneSuppression.lungVAE.models.VAE as VAE
import RenGeBoneSuppression.lungSegFunctions_for_external_scripts as LF
import RenGeBoneSuppression.net as RGnet


import torchvision.transforms as tvtransforms
import torchvision.transforms.functional as TF

directory = "D:/data/DongrongNetwork/DanielDataSets_Original/"
resized_shape = (224,224)
new_folder_name = "Reduced_224_Ren" # change "Original" to new_folder_name for the new dataset
switch = "Ren"
PATH_BSNET_CKPT = ["./RenGeBoneSuppression/g1_net0474-0.000.pth",
                   "./RenGeBoneSuppression//g2_net0474-0.000.pth"]#"C:/Users/nfdlam/Desktop/GusarevBoneSuppression/runs/6LayerCNN/v5_HQ_noEqualised_177-20-20/network_final.tar"

class BoneSuppression(object):
    def __init__(self, switch=None, PATH_BSNET_CKPT=None,resized_shape=(224,224)):
        self.switch = switch
        self.PATH_BSNET_CKPT = PATH_BSNET_CKPT
        self.resized_shape=resized_shape
        self.device = torch.device("cuda" if (torch.cuda.is_available()) else 'cpu')
        self.BATCH_SIZE = 64
        print(self.device)
    def bone_suppression(self,image):
        image = image.to(self.device)
        if self.switch == "Rajaraman" or self.switch == "Gusarev":
            out = self.bone_suppression_GusarevRajaraman(image)
        elif self.switch == "Ren":
            out = self.bone_suppression_Ren(image)
        else:
            raise RuntimeError("Not a known suppression method")
        return out
    
    def bone_suppression_Ren(self,image, flag_segmentLung=True, flag_compositeBodyAndLung=True):
        # assumes PATH_BSNET_CKPT is a sequence of 2 strings [net1, net2]
        device = self.device
        
        # Lung Segmentation DEFAULT SETTINGS
        
        model="./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt"
        hidden=16
        latent=8
        unet=False
        dicom=False
        no_post = False
        flag_equaliseOriginalImages = False
        flag_normalise = False
        flag_cropping=False
        p = 32 # padding
        
        print("Loading "+model)
        if 'unet' in model:
            unet = True
            hidden = int(1.5*hidden)
        else:
            unet = False
        net_lungSeg = VAE.uVAE(nhid=hidden,nlatent=latent,unet=unet)
        net_lungSeg.load_state_dict(torch.load(model, map_location=device))
        net_lungSeg.to(device)
        #nParam = sum(p.numel() for p in net_lungSeg.parameters() if p.requires_grad)
        #print("Model "+model.split('/')[-1]+" Number of parameters:%d"%(nParam))
        
        # Network download
        #GENERATOR1 = "./g1_net0474-0.000.pth"
        #GENERATOR2 = "./g2_net0474-0.000.pth"
        GENERATOR1 = self.PATH_BSNET_CKPT[0]
        GENERATOR2 = self.PATH_BSNET_CKPT[1]
        input_array_size = (self.BATCH_SIZE, 1, self.resized_shape[0], self.resized_shape[1])
        G1 = RGnet.Generator_first()
        G2 = RGnet.Generator_second()
        G1.load_state_dict(torch.load(GENERATOR1, map_location=device))
        G2.load_state_dict(torch.load(GENERATOR2, map_location=device))

        G1 = G1.float()
        G2 = G2.float()

        G1 = G1.to(device)
        G2 = G2.to(device)
        # Set to testing mode
        G1.eval()
        G2.eval()
        
        # Suppression IN LOOP
        torchresize = tvtransforms.Resize(self.resized_shape, interpolation=TF.InterpolationMode.NEAREST)
        with torch.no_grad():
            if flag_segmentLung:
                input_data , output_data = self.GaryPreprocessingInputData(image.to(device), net_lungSeg.to(device), p,
                                                                      no_post, flag_equaliseOriginalImages, flag_normalise, flag_cropping)
                input_data = input_data.to(device)
            else:
                input_data = torchresize(data[key_source].to(device))
            
            # Bone Suppression
            output1 = G1(input_data.float())
            input2 = torch.cat((input_data.float(), output1), 1)
            out_uncomposited = G2(input2)
            out = out_uncomposited.cpu()
            
        # Cut out the lung mask from the input_data and paste it into the original image.
        if flag_segmentLung and flag_compositeBodyAndLung:
            print("Compositing...")
            out = self.compositeImage(out, output_data, flag_cropping)
            
        return out
        
    def GaryPreprocessingInputData(self, image, net_lungSeg, p=32, no_post=False, 
                               flag_equaliseOriginalImages=False, flag_normalise=False, flag_cropping=False):
        
        # Ren Ge's network is trained on black-bone images
        # Standardise Input Image Monochrome
        stdMono = custom_transforms.StandardiseMonochrome(sample_keys_images=["test"], standard="MONOCHROME2", verbose=False)
        # image going into tform needs to be [CxHxW]
        temp_image, _ = stdMono.tform(image.squeeze().unsqueeze(0))
        image = temp_image.unsqueeze(0)

        # Segment the lung mask
        # Original image is equalised if no_preprocess=False
        no_preprocess = not flag_equaliseOriginalImages # because equalisation will occur in the lung segmentation aspect
        
        output_data = LF.lungSegmentation_maskOnly(self.device, net_lungSeg, image, p=p, 
                                                no_preprocess=no_preprocess, standardisedMonochrome="MONOCHROME1",no_post=no_post)

        # For boneless
        if False:
            print("There are boneless images in the data.")
            output_data[key_boneless] = []
            if flag_equaliseOriginalImages:
                for image in data[key_boneless]:
                    image = image.squeeze().numpy() #[HxW]
                    image = skimage.exposure.equalize_hist(image)
                    output_data[key_boneless].append(torch.Tensor(image).unsqueeze(0).unsqueeze(0)) #[1x1xHxW]
            else:
                for image in data[key_boneless]:
                    image = image.squeeze().numpy() #[HxW]
                    output_data[key_boneless].append(torch.Tensor(image).unsqueeze(0).unsqueeze(0)) #[1x1xHxW]
            output_data[key_boneless] = torch.cat(output_data[key_boneless])

        # Multiply image * lung
        torchresize = tvtransforms.Resize(self.resized_shape, interpolation=TF.InterpolationMode.NEAREST)
        masked = output_data["image"].cpu()*output_data["mask"].cpu()
        # Crop to lung and resize
        if flag_cropping:
            output_data["croppedMaskedImage"] =[]
            output_data["croppedMask"] =[]
            output_data["croppedImage"] =[]
            if False:
                output_data["cropped"+key_boneless]=[]
            for idx2, mask_image in enumerate(output_data["mask"]):
                bb = BoundingBox(mask_image)
                indices = bb.findBox()
                # crop the mask & the masked image
                croppedMaskedImage = masked[idx2,:,indices["topbottom"][0]:indices["topbottom"][1]+1, indices["leftright"][0]:indices["leftright"][1]+1]
                croppedMask = output_data["mask"][idx2,:,indices["topbottom"][0]:indices["topbottom"][1]+1, indices["leftright"][0]:indices["leftright"][1]+1]
                croppedImage = output_data["image"][idx2,:,indices["topbottom"][0]:indices["topbottom"][1]+1, indices["leftright"][0]:indices["leftright"][1]+1]
                if key_boneless in data.keys():
                    croppedBoneless = output_data[key_boneless][idx2,:,indices["topbottom"][0]:indices["topbottom"][1]+1, indices["leftright"][0]:indices["leftright"][1]+1]
                # Resize image to 256x256
                output_data["croppedMaskedImage"].append(torchresize(croppedMaskedImage))
                output_data["croppedMask"].append(torchresize(croppedMask))
                output_data["croppedImage"].append(torchresize(croppedImage))
                if False:
                    output_data["cropped"+key_boneless].append(torchresize(croppedBoneless))
            output_data["croppedMaskedImage"] = torch.stack(output_data["croppedMaskedImage"])
            output_data["croppedMask"] = torch.stack(output_data["croppedMask"])
            output_data["croppedImage"] = torch.stack(output_data["croppedImage"])
            if False:
                output_data["cropped"+key_boneless] = torch.stack(output_data["cropped"+key_boneless])

            # IMPORTANT DATA OUTPUT
            maskedImage = output_data["croppedMaskedImage"]
        else:
            maskedImage = torchresize(masked)

        # Normalise masked images
        if flag_normalise:
            normalised = []
            for image in maskedImage:
                image = image.squeeze().numpy()
                image = normalisation(image)
                image = torch.from_numpy(image).unsqueeze(0).unsqueeze(0) #[1x1xHxW]
                normalised.append(image)
            normalised = torch.cat(normalised)
            input_data = normalised
        else:
            input_data = maskedImage
        return input_data, output_data
    def compositeImage(self, out, output_data, flag_cropping):
        # Re-paste the suppressed lung segment back into the OG image
        image_spatial_size = (out.shape[-2],out.shape[-1])
        torchresize = tvtransforms.Resize(image_spatial_size, interpolation=TF.InterpolationMode.NEAREST)

        if flag_cropping:
            mask = output_data["croppedMask"]
            OG_image = output_data["croppedImage"]
        else:
            mask = output_data["mask"]
            OG_image = output_data["image"]
            mask = torchresize(mask)
            OG_image = torchresize(OG_image)

        composited = []
        for minibatch_idx, lung in enumerate(out):
            mask_current = mask[minibatch_idx,:] #[CxHxW]
            # flip mask_current's Trues to Falses and vice versa
            body_mask = ~mask_current
            body = body_mask.cpu()*OG_image[minibatch_idx,:].cpu()
            if body.shape[-2] != lung.shape[-2] and body.shape[-1] != lung.shape[-1]:
                body = torchresize(body)
            composited.append(body + lung)
        out = torch.stack(composited)
        if image_spatial_size is not None:
            out = torchresize(out)
        return out
    ##################
    # Gusarev and Rajaraman
    ##################
    def bone_suppression_GusarevRajaraman(image):
        device = self.device
        if self.switch == None:
            return image
        else:
            if self.PATH_BSNET_CKPT is None:
                raise RuntimeError("PATH_BSNET_CKPT is None.")

        # Which suppression method?
        input_array_size = (1, 1, self.resized_shape[-2], self.resized_shape[-1])
        if self.switch == "Rajaraman":
            net = RajaramanModel.ResNet_BS(input_array_size)
            net = net.to(device)
        elif self.switch == "Gusarev":
            net = GusarevModel.MultilayerCNN(input_array_size)
            net = net.to(device)
        else:
            raise RuntimeError("Not a known suppression method")
            
        # Load pre-trained parameters into net
        if os.path.isfile(self.PATH_BSNET_CKPT):
                checkpoint = torch.load(self.PATH_BSNET_CKPT, map_location=device)
                start_epoch = checkpoint['epoch_next']
                reals_shown_now = checkpoint['reals_shown']
                net.load_state_dict(checkpoint['model_state_dict'])
        else:
            print("=> NO CHECKPOINT FOUND AT '{}'" .format(PATH_BSNET_CKPT))
            raise RuntimeError("No checkpoint found at specified path.")
        # Set to evaluation mode:
        net.eval()

        # Input data must be a 4D Torch Tensor
        if len(image.shape) != 4:
            raise RuntimeError("Image must be a 4D Torch Tensor.  Current image shape is {}".format(image.shape))
        # Suppress:
        out = net(image.to(device))
        out = out.cpu()
        return out

    
BS = BoneSuppression(switch, PATH_BSNET_CKPT , resized_shape )
for root, dirs, files in os.walk(directory):
    for name in files:
        filepath = os.path.join(root, name)
        ## Load image in filepath, image process, then save
        # Load
        image = plt.imread(filepath) #[range 0-1]
        # Pre-processing
        image = resize(image, resized_shape, order=0) #[range 0-1]
        
        # Bone suppression
        image = torch.from_numpy(image).unsqueeze(0).unsqueeze(0) # 4D Torch Tensor
        image_tensor = BS.bone_suppression(image) # range[0-1]
        image = image_tensor.detach().squeeze().numpy() 
        
        # Post-processing
        pImg = (image*255).astype(np.uint8)
        # Save
        filepath2 = filepath.replace("Original",new_folder_name)
        a = Path(os.path.split(filepath2)[0])
        a.mkdir(parents=True, exist_ok=True)        
        
        I8 = (((pImg - pImg.min()) / (pImg.max() - pImg.min())) * 255).astype(np.uint8)
        img = Image.fromarray(I8)
        img.save(filepath2)
print("Complete.")
print("Last file saved: {}".format(filepath2))

cuda
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE.pt
Compositing...
Loading ./RenGeBoneSuppression/lungVAE/saved_models/lungVAE

In [None]:
image.shape