In [16]:
# latest preprocessing
import cv2, os, glob, h5py, yaml, math
import pathlib
import scipy.io
import pandas as pd
from dataUtils import octSpectralisReader as osr
from dataUtils.preprocessData import preprocessData
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from astropy.convolution import convolve
from scipy.signal import find_peaks
from scipy.io import savemat
from dataUtils.misc import build_mask, sp_noise
from dataUtils.retinaFlatten import retinaFlatten
from dataUtils.retinaDetect import retinaDetector

In [17]:
config_path = "preprocess_config_JH.yaml"

# safe_load only builds plain YAML types (mappings, lists, scalars), which is all this
# config needs; yaml.load with FullLoader allowed arbitrary code execution in
# PyYAML < 5.4 (CVE-2020-14343). Explicit UTF-8 avoids platform-default decoding.
with open(config_path, encoding="utf-8") as file:
    config = yaml.safe_load(file)

In [18]:
class processData:
    """Build training/val/test OCT datasets from raw volumes according to a config dict.

    The config must provide ``general``, ``filepaths`` and ``algorithm`` sections
    (the keys read are listed in ``__init__``). ``filepaths.filename`` selects the
    dataset loader; see ``_SUPPORTED_DATASETS``.
    """

    # Dataset identifiers that prepare_dataset / process_one_file know how to load.
    _SUPPORTED_DATASETS = ("JH", "JH_MS", "MIAMI_HC", "MIAMI_DME")

    def __init__(self, config):
        general = config['general']
        filepaths = config['filepaths']

        self.layers = general['layers']
        self.bscans = general['bscans']
        self.width = general['width']
        self.sliding_window = general['sliding_window']
        self.height = general['height']
        self.stride = general['stride']
        self.top_offset = general['top_offset']
        self.split_ratio = general['split_ratio']
        self.seed = general['seed']
        self.getPatches = general['getPatches']
        self.crop = general['crop']
        self.constructSyntheticMap = general['constructSyntheticMap']

        self.image_path = filepaths['image_path']
        self.label_path = filepaths['label_path']
        self.processed_path = filepaths['processed_path']
        self.group = filepaths['group']
        self.filename = filepaths['filename']
        self.save_filename = filepaths['save_filename']

        self.data_list = general['data_list']

        self.datasets = {"training": None, "val": None, "test": None}
        self.postprocessed_data = None

        self.algorithm = config['algorithm']

    def save_data(self, df, mode):
        """Write the outputs of one split into ``<processed_path>/<save_filename>/``.

        Args:
            df: per-window metadata table, written to ``<mode>_reconstruct_data.csv``.
            mode: split name ("training", "val" or "test"), used in the file names.

        Also writes ``<mode>_dataset.txt`` (one case of the split per line) and
        ``<mode>_intermediate.hdf5`` (one dataset per entry of ``data_list``, taken
        from ``self.postprocessed_data``).
        """
        out_dir = os.path.join(self.processed_path, self.save_filename)

        df.to_csv(os.path.join(out_dir, '{}_reconstruct_data.csv'.format(mode)))

        with open(os.path.join(out_dir, '{}_dataset.txt'.format(mode)), 'w') as f:
            for case in self.datasets[mode]:
                f.write("{}\n".format(case))

        # The context manager closes the HDF5 file; no explicit close() is needed.
        with h5py.File(os.path.join(out_dir, '{}_intermediate.hdf5'.format(mode)), 'w') as hf:
            for datatype in self.data_list:
                hf.create_dataset(datatype, data=np.array(self.postprocessed_data[datatype]))

    def prepare_dataset(self):
        """Discover the cases for ``self.filename`` and split them into training/val/test.

        Cases are sorted by name before splitting, so the split depends only on
        ``split_ratio`` and ``seed`` and not on filesystem enumeration order.

        Raises:
            ValueError: if ``self.filename`` is not a supported dataset.
        """
        if self.filename in ('JH', 'JH_MS'):
            # File names (anywhere under image_path) that start with the group prefix.
            list_of_cases = [file for _root, _dirs, files in os.walk(self.image_path)
                             for file in files if file.startswith(self.group)]
        elif self.filename == 'MIAMI_HC':
            list_of_cases = glob.glob(os.path.join(self.image_path, self.filename, '*.mat'))
        elif self.filename == 'MIAMI_DME':
            list_of_cases = glob.glob(os.path.join(self.image_path, self.filename, 'Patient*.mat'))
        else:
            raise ValueError("Unsupported dataset {!r}; expected one of {}".format(
                self.filename, self._SUPPORTED_DATASETS))

        # os.walk and glob return entries in arbitrary order; sort for a reproducible split.
        list_of_cases = sorted(list_of_cases)

        # The val split is carved out of the remaining (non-test) cases with the same ratio.
        training_dataset, self.datasets["test"] = train_test_split(
            list_of_cases, test_size=self.split_ratio, random_state=self.seed)
        self.datasets["training"], self.datasets["val"] = train_test_split(
            training_dataset, test_size=self.split_ratio, random_state=self.seed)

    def _crop_top(self, window_mask):
        """Return the first row of a crop window, derived from that window's retina mask.

        The row is the first row with a non-zero mask value, minus ``top_offset``.
        It is clamped to ``[0, rows - height]`` so the slice always has ``height`` rows.
        An all-zero ``window_mask`` raises ValueError (numpy min of an empty array).
        """
        rows = np.nonzero(window_mask)[0]
        top = int(rows.min()) - self.top_offset
        max_top = window_mask.shape[0] - self.height
        return max(0, min(top, max_top))

    def get_sliding_window(self, data_store, patient_name):  # one patient
        '''Cut every B-scan of one patient into fixed-size windows (patches).

        Windows are ``sliding_window`` columns wide and ``height`` rows tall. They
        start every ``stride`` columns (overlapping when stride < sliding_window), and
        the last window ends at column ``width``. With ``crop`` enabled, each window's
        top row follows the retina mask (see ``_crop_top``); otherwise it is row 0.

        Args:
            data_store: dict of per-patient arrays. 'lmap' is indexed as
                [:, rows, cols, scan]; all other entries as [rows, cols, scan].
                'rmask' is required when ``crop`` is enabled.
            patient_name: identifier recorded in the returned metadata.

        Returns:
            post: dict mapping each entry of ``data_list`` (plus 'lmap') to an array of
                shape (bscans, n_windows, height, sliding_window). 'lmap' has an extra
                ``layers`` axis after n_windows. Entries missing from ``data_store``
                stay zero-filled.
            all_data: one list per scan of [patient_name, scan, top, bottom, left, right]
                rows, one row per window.
        '''
        # "+ 1" includes the window flush with the right edge (start = width - sliding_window).
        indices = list(range(0, self.width - self.sliding_window + 1, self.stride))
        post = {k: np.zeros((self.bscans, len(indices), self.height, self.sliding_window))
                for k in self.data_list}
        post['lmap'] = np.zeros((self.bscans, len(indices), self.layers, self.height, self.sliding_window))

        all_data = []
        for scan in range(self.bscans):
            one_data = []
            for idx, left in enumerate(indices):
                right = left + self.sliding_window
                if self.crop:
                    top = self._crop_top(data_store['rmask'][:, left:right, scan])
                else:
                    top = 0
                bottom = top + self.height

                for datatype in self.data_list:
                    if datatype not in data_store:
                        continue
                    if datatype == 'lmap':
                        post[datatype][scan, idx] = data_store[datatype][:, top:bottom, left:right, scan]
                    else:
                        post[datatype][scan, idx] = data_store[datatype][top:bottom, left:right, scan]

                one_data.append([patient_name, scan, top, bottom, left, right])

            all_data.append(one_data)

        return post, all_data

    def _run_preprocessing(self, volume, header, annotations):
        """Run ``preprocessData`` on one volume and return its ``data_store`` dict."""
        # Copy so the shared config dict is not mutated on every call.
        preproc_params = dict(self.algorithm['preproc_params'])
        preproc_params['retinadetector_type'] = self.algorithm['types']

        # Named `preprocessor` (not `pd`) so pandas is not shadowed. The probs come from
        # self.algorithm rather than a notebook-global `config`.
        preprocessor = preprocessData(volume, header, preproc_params, self.algorithm['probs'],
                                      self.algorithm['scanner_type'], annotations, self.data_list)
        preprocessor.preprocess()
        return preprocessor.data_store

    def process_vol_JH(self, file, annotations):
        '''Read a Spectralis file with octSpectralisReader and preprocess the B-scans.

        code taken from: https://github.com/steventan0110/OCT_preprocess. Added things like positional map and synthetic map
        '''
        header, _bscan_header, _slo, bscans = osr.octSpectralisReader(file)
        header['angle'] = 0
        return self._run_preprocessing(bscans, header, annotations)

    def _process_vol_miami(self, images, annotations, scale_z, scale_x):
        """Build a header for a Miami .mat volume and preprocess it.

        ``images`` is read as [rows, cols, bscans]: SizeZ comes from axis 0 and
        NumBScans from axis 2. Intensities are divided by 255 before preprocessing.
        """
        header = {
            'SizeX': self.width,
            'NumBScans': images.shape[2],
            'SizeZ': images.shape[0],
            'Distance': 0.13,  # presumably the B-scan separation — TODO confirm
            'ScaleZ': scale_z,
            'ScaleX': scale_x,  # presumably the lateral resolution — TODO confirm
            'angle': 0,
        }
        return self._run_preprocessing(images / 255, header, annotations)

    def process_vol_Miami(self, images, annotations):
        """Preprocess a MIAMI_HC volume (ScaleZ=0.0039, ScaleX=0.012)."""
        return self._process_vol_miami(images, annotations, scale_z=0.0039, scale_x=0.012)

    def process_vol_Miami_DME(self, images, annotations):
        """Preprocess a MIAMI_DME volume (scales obtained from the dataset paper)."""
        return self._process_vol_miami(images, annotations, scale_z=0.00387, scale_x=0.01111)

    def get_wmap(self, data_store):
        """Add a per-pixel weight map ``data_store['wmap']`` and return ``data_store``.

        weight = 1 + 5 * (retinal mask) + 10 * (label differs from the row above).
        A pixel's label is the argmax of 'lmap' over axis 0. The row above row 0 is
        treated as label 8.
        """
        labels = np.argmax(data_store['lmap'], axis=0)
        # Labels shifted down one row. NOTE(review): 8 is presumably the background
        # class index — confirm against self.layers.
        labels_above = np.full(labels.shape, 8.0)
        labels_above[1:] = labels[:-1]
        is_boundary = (labels - labels_above) != 0
        data_store['wmap'] = 1 + 5 * data_store['rmask'].astype(int) + 10 * is_boundary.astype(int)
        return data_store

    def process_one_file(self, dataset):
        """Load one case, preprocess it and attach the weight map.

        Args:
            dataset: a case from ``self.datasets``. This is a file name for JH/JH_MS
                and a path to a .mat file for the Miami datasets.

        Returns:
            (data_store, patient_name)

        Raises:
            ValueError: if ``self.filename`` is not a supported dataset.
        """
        if self.filename in ('JH', 'JH_MS'):
            patient_name = os.path.splitext(dataset)[0]
            label_path = os.path.join(self.label_path, self.filename, patient_name + '_label.mat')
            image_path = os.path.join(self.image_path, self.filename, dataset)
            annotations = scipy.io.loadmat(label_path)['bd_pts']  # e.g. 1024*49*9 (all segmentations of 1 patient)
            data_store = self.process_vol_JH(image_path, annotations)

        elif self.filename in ('MIAMI_HC', 'MIAMI_DME'):
            patient_name = os.path.splitext(os.path.basename(dataset))[0]
            # Each Miami .mat file holds both the images and the annotations.
            data_path = os.path.join(self.image_path, self.filename, patient_name + '.mat')
            mat = scipy.io.loadmat(data_path)
            if self.filename == 'MIAMI_HC':
                data_store = self.process_vol_Miami(mat['volumedata'], mat['Observer1'])
            else:
                # Move the stored axis 0 to the last position.
                annotations = np.transpose(mat['annotations'], (1, 2, 0))
                data_store = self.process_vol_Miami_DME(mat['images'], annotations)

        else:
            raise ValueError("Unsupported dataset {!r}; expected one of {}".format(
                self.filename, self._SUPPORTED_DATASETS))

        return self.get_wmap(data_store), patient_name

    def create_dataset(self, mode):
        '''Preprocess every case of one split and collect the results, by patient.

        Per-patient outputs as described by the author (shapes are for the JH data):
        img_vol: flattened retina mask, from 0-1, shape: (496, 1024, 49)
        positional_map: flattened relative position map from top to bottom layer, from 0-1, shape: (496, 1024, 49)
        layer_map: flattened ground truth labels, 9 channels (8 layers + background), shape: (9, 496, 1024, 49)
        image_with_noise: synthetic images with noise, shape: (496, 1024, 49)
        Bscans: original OCT images

        Returns a dict with one list per ``data_list`` entry, plus 'patient' (case ids)
        and 'meta_data' (window metadata, filled only when ``getPatches`` is set).
        Also writes ``<mode>_order_files.txt`` with the cases in processing order.
        This file is recreated on each call.
        '''
        postprocessed_data = {k: [] for k in self.data_list}
        postprocessed_data['patient'] = []
        postprocessed_data['meta_data'] = []

        if self.filename not in self._SUPPORTED_DATASETS:
            return postprocessed_data

        order_path = os.path.join(self.processed_path, self.save_filename,
                                  '{}_order_files.txt'.format(mode))
        # Truncate once per run so re-running a split does not append duplicate entries.
        pathlib.Path(order_path).write_text('')

        for dataset in self.datasets[mode]:
            data_store, patient_name = self.process_one_file(dataset)

            if self.getPatches:
                data_store, all_data = self.get_sliding_window(data_store, patient_name)
                postprocessed_data['meta_data'].append(all_data)

            for datatype in self.data_list:
                if datatype in data_store:
                    postprocessed_data[datatype].append(data_store[datatype])
            postprocessed_data['patient'].append(dataset)

            # Appended after each case, so the file shows progress if a later case fails.
            with open(order_path, 'a') as f:
                f.write(dataset + "\n")

        return postprocessed_data

    def make_dataset(self, mode, save):
        '''Build one split and optionally save it.

        mode refers to training, test or val.
        save: when truthy, write the CSV/txt/HDF5 outputs via ``save_data``.
        '''
        pathlib.Path(self.processed_path, self.save_filename).mkdir(parents=True, exist_ok=True)

        self.postprocessed_data = self.create_dataset(mode)
        # Flatten meta_data (patient -> scan -> window rows) into one row per window.
        # Reads self, not the notebook-global `process_data`.
        flat_list = [row
                     for patient_meta in self.postprocessed_data['meta_data']
                     for scan_meta in patient_meta
                     for row in scan_meta]
        if save:
            df = pd.DataFrame(flat_list, columns=['patient_name', 'slice_number', 'top', 'bottom', 'left', 'right'])
            self.save_data(df, mode)

In [19]:

# Build the processor from the loaded config, then split the discovered cases
# into training / val / test lists.
process_data = processData(config=config)
process_data.prepare_dataset()

In [20]:

# Preprocess and save each split in turn: training first, then val, then test.
for split_name in ("training", "val", "test"):
    process_data.make_dataset(split_name, True)

done! 227 outlier points
Preparing S map
done! 1975 outlier points

Preparing S map
done! 2229 outlier points

Preparing S map
done! 1273 outlier points

Preparing S map
done! 2426 outlier points

Preparing S map
done! 191 outlier points
Preparing S map
done! 519 outlier points

Preparing S map
done! 142 outlier points
Preparing S map
done! 111 outlier points
Preparing S map
done! 733 outlier points

Preparing S map
done! 2602 outlier points

Preparing S map
done! 417 outlier points
Preparing S map
done! 1513 outlier points

Preparing S map
done! 740 outlier points

Preparing S map
