This notebook uses a trained model (created by unet_dev.ipynb or mrunet_dev.ipynb) to predict the area of pericardial fat in a (much!) larger subsample of the UK Biobank dataset.  

In [None]:
import pandas as pd

import numpy as np

from mask_utils import load_image,resample_image,pad_voxels

from tensorflow.keras.models import model_from_json

import os

from network_utils import gpu_memory_limit,predict_stochastic
from MultiResUNet.MultiResUNet import MultiResUnet

import pickle

import tempfile

import zipfile

import re

import glob

import tensorflow as tf

import pydicom as dcm
import nibabel as nib

In [None]:
# Cap how much GPU RAM this notebook is allowed to allocate.
# NOTE(review): the previous comment said "8GB is 1/3 of available", but the
# value passed below is 6000. The unit is defined by
# network_utils.gpu_memory_limit (presumably MB -- confirm) and the comment and
# the value should be reconciled.
gpu_memory_limit(6000)

# Alternative kept for reference: let TensorFlow grow GPU memory on demand
# instead of imposing a fixed cap.
# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
#     try:
#         for gpu in gpus:
#             tf.config.experimental.set_memory_growth(gpu, True)
#     except RuntimeError as e:
#         print(e)

In [None]:
# Load the preprocessing constants that were pickled by an earlier notebook:
#   PADSIZE   - passed as height/width to the network and as padSize to load_image
#   PXSPACING - pixel spacing every image is resampled to before prediction
# The files are opened with context managers so the handles are always closed.
try:
    with open(os.path.join('data','PADSIZE.pickle'),'rb') as _pickleFile:
        PADSIZE = np.array(pickle.load(_pickleFile))
    with open(os.path.join('data','PXSPACING.pickle'),'rb') as _pickleFile:
        PXSPACING = np.array(pickle.load(_pickleFile))
except (OSError, pickle.UnpicklingError, EOFError):
    print('PICKLES NOT FOUND OR UNREADABLE. try running extract_dcm_for_wsx.ipynb')
    # Nothing below can work without these constants, so stop here with the
    # real error instead of failing later with a confusing NameError.
    raise

# Area of a single resampled pixel (product of the spacings). quantify_fat
# divides by 100 to report cm2, so this is presumably in mm2 -- confirm.
# np.prod replaces np.product, which was removed in NumPy 2.0.
PXAREA = np.prod(PXSPACING)

First, load the model

In [None]:
# Base name of the trained model to use; the weights (.h5), the architecture
# (.json) and the accuracy model (.pickle) all share it.
modelBaseName = 'mrunet_bayesian_2020-07-13_13-40' 

# Location of the actual saved model on disk.
modelBaseName = os.path.join('data','models',modelBaseName)

modelParamFile = modelBaseName + '.h5'
modelArchitecture = modelBaseName + '.json'

# Loading the architecture from JSON has rotted due to TensorFlow upgrades, so
# the network is rebuilt in code and only the weights are loaded from disk.
# with open( modelArchitecture , 'r') as json_file:
#     MODEL = model_from_json( json_file.read() )

MODEL = MultiResUnet(height=PADSIZE[0],
                     width=PADSIZE[1],
                     n_channels=1,
                     layer_dropout_rate=None,
                     block_dropout_rate=0.25 # extracted from the json file, which is now unreadable
                    )

MODEL.load_weights(modelParamFile)

# Hyperparameter N, chosen according to quantify_model_performance.ipynb; it is
# passed to predict_stochastic (presumably the number of stochastic forward
# passes -- confirm against network_utils).
N = 15

# Model that converts the stochastic-prediction statistics into a predicted
# accuracy; loaded with a context manager so the file handle is closed.
accuracyModelPath = modelBaseName + '_prediction_conversion.pickle'
with open(accuracyModelPath,'rb') as accuracyModelFile:
    ACCURACYMODEL = pickle.load(accuracyModelFile)
#IF THIS BREAKS IN FUTURE:
#coefficient is 1.63920111
#intercept is -0.66730187

# Where the segmentations and the results table are written.
OUTPUT_DIRECTORY = os.path.join('data','pericardial','UKB_segmentations_for_Esmeralda')
RESULTSFILE = os.path.join(OUTPUT_DIRECTORY,'UKB_pericardial_fat_predictions.csv')

Now, collect the long-axis zip files and define the helpers used for image preprocessing:


In [None]:
# Gather every long-axis (LAX) zip archive found anywhere beneath
# data/imaging_by_participant, in sorted path order.
allZips = glob.glob(
    os.path.join('data', 'imaging_by_participant', '**', '*_longaxis.zip'),
    recursive=True,
)
allZips.sort()

In [None]:
def get_manifest(zipfileObject):
    """Read the manifest table stored in an open zip archive.

    Parameters
    ----------
    zipfileObject : zipfile.ZipFile
        Archive opened for reading.

    Returns
    -------
    pandas.DataFrame or None
        The manifest parsed as CSV, or None (after printing a message) when the
        archive does not contain exactly one member whose name starts with
        'manifest'.
    """
    # The previous pattern 'manifest*' was a regex meaning "manifes" followed by
    # any number of "t"s, so it also matched unrelated names such as
    # 'manifes_notes.txt'. A plain prefix test expresses the intent.
    manifestFiles = [f for f in zipfileObject.namelist() if f.startswith('manifest')]

    if len(manifestFiles) != 1:
        print('no manifest found')
        return None

    # Parse straight from the archive; there is no need to unpack to disk first.
    with zipfileObject.open(manifestFiles[0]) as manifestFile:
        return pd.read_csv(manifestFile, index_col=False)

In [None]:
def first_image_in_series(zipfileObject,listOfDicomFiles):
    """Find the DICOM in a series whose TriggerTime is zero.

    Parameters
    ----------
    zipfileObject : zipfile.ZipFile
        Open archive containing the DICOM files.
    listOfDicomFiles : sequence of str
        Archive member names belonging to one series.

    Returns
    -------
    str or None
        Name of the first member (in sorted order) with TriggerTime == 0.0, or
        None when no member qualifies.
    """
    # Sort first: usually the file with the lowest trigger time is also the
    # first one after sorting, so the loop tends to stop immediately.
    sortedList = np.sort(listOfDicomFiles)

    with tempfile.TemporaryDirectory() as tempDir:
        for dicom in sortedList:
            zipfileObject.extract(dicom,path=tempDir)
            # dcmread replaces the deprecated read_file; only the header is
            # needed here, so skip decoding the pixel data.
            header = dcm.dcmread(os.path.join(tempDir,dicom), stop_before_pixels=True)
            # Files without a TriggerTime are skipped instead of raising.
            if getattr(header, 'TriggerTime', None) == 0.0:
                return dicom

    return None

def extract_first_4Ch_image(zipfilePath):
    """Load the first frame of the 4-chamber long-axis cine from a zip archive.

    The manifest inside the archive is used to find files labelled
    'CINE_segmented_LAX_4Ch'. If they belong to a single series, the frame with
    TriggerTime 0 of that series is used. If there are several series, only
    those with exactly 50 entries are considered and the one with the latest
    SeriesTime is chosen (assuming a later acquisition is a better one).

    Parameters
    ----------
    zipfilePath : str
        Path to the long-axis zip archive.

    Returns
    -------
    tuple
        (image, imagedDate, dicom_object) as produced by load_image and the
        manifest 'date' column, or (None, None, None) when no usable 4-chamber
        image could be found or loaded.
    """
    seriesKeys = ['series discription','seriesid']

    # The context manager guarantees the archive is closed on every exit path;
    # this function is called once per participant, so leaked handles add up.
    with zipfile.ZipFile(zipfilePath) as zipfileObject:

        manifest = get_manifest(zipfileObject)
        if manifest is None:
            # no (unique) manifest in the archive
            return None,None,None

        #index for 4ch images
        Index4Ch = (manifest =='CINE_segmented_LAX_4Ch').max(axis=1)
        if not Index4Ch.any():
            #if nothing labelled as a 4-chamber image, return nothing
            return None,None,None

        #get only the 4Chamber ones
        manifest = manifest.loc[Index4Ch,:]
        #separate the series
        series = manifest.groupby(seriesKeys)
        #get the date
        imagedDate = manifest['date'].iloc[0]

        with tempfile.TemporaryDirectory() as tempDir:
            if series.ngroups == 1:
                #only one series: take its first image
                firstDicom = first_image_in_series(zipfileObject,manifest['filename'].values)
                if firstDicom is None:
                    return None,None,None
                zipfileObject.extract(firstDicom,path=tempDir)

            else:
                #more than one series: keep those with exactly 50 images
                manifest = series.filter(lambda x: x.count().max() == 50)

                #collect (SeriesTime, first dicom) for every remaining series
                candidates = []
                for _, group in manifest.groupby(seriesKeys):
                    candidate = first_image_in_series(zipfileObject,group['filename'].values)
                    if candidate is None:
                        continue
                    zipfileObject.extract(candidate,path=tempDir)
                    header = dcm.dcmread(os.path.join(tempDir,candidate), stop_before_pixels=True)
                    candidates.append((str(getattr(header,'SeriesTime','')),candidate))

                if not candidates:
                    #no complete series with a usable first frame
                    return None,None,None

                #and use the latest one (assuming it will be better...)
                firstDicom = max(candidates,key=lambda c: c[0])[1]

            try:
                dicom_location = os.path.join(tempDir,firstDicom)
                image,spacing,dicom_object = load_image(dicomPath=dicom_location,desiredPxSpacing=PXSPACING, padSize=PADSIZE,return_dicom_object=True)
                return image,imagedDate,dicom_object
            except Exception:
                # Best effort: an unreadable image is reported as "nothing
                # found" so the batch run can carry on with the next archive.
                return None,None,None

In [None]:
def get_affine(d):
    """Build the 4x4 voxel-to-world affine for a single 2D DICOM slice.

    DICOM positions and orientations are given in LPS coordinates; their x and
    y components are negated here so the affine is expressed in the RAS
    convention used by NIfTI.

    Parameters
    ----------
    d : object
        Must expose PixelSpacing, ImagePositionPatient, ImageOrientationPatient
        and either SpacingBetweenSlices or SliceThickness (e.g. a pydicom
        dataset).

    Returns
    -------
    numpy.ndarray
        4x4 affine whose first three columns are the scaled x, y and z axes and
        whose last column is the position of the upper-left voxel.
    """
    # Copyright 2017.
    # Author: Wenjia Bai, Biomedical Image Analysis Group, Imperial College London.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================
    # Modified: reduced to the single-slice (2D) case used in this notebook; the
    # multi-slice branches were unreachable and referenced undefined names.

    dx = float(d.PixelSpacing[1])
    dy = float(d.PixelSpacing[0])

    # DICOM coordinate (LPS)
    #  x: left
    #  y: posterior
    #  z: superior
    # Nifti coordinate (RAS)
    #  x: right
    #  y: anterior
    #  z: superior
    # Therefore, to transform between DICOM and Nifti, the x and y coordinates need to be negated.
    # Refer to
    # http://nifti.nimh.nih.gov/pub/dist/src/niftilib/nifti1.h
    # http://nifti.nimh.nih.gov/nifti-1/documentation/nifti1fields/nifti1fields_pages/figqformusage

    # The coordinate of the upper-left voxel
    pos_ul = np.array([float(x) for x in d.ImagePositionPatient])
    pos_ul[:2] = -pos_ul[:2]

    # Image orientation
    axis_x = np.array([float(x) for x in d.ImageOrientationPatient[:3]])
    axis_y = np.array([float(x) for x in d.ImageOrientationPatient[3:]])
    axis_x[:2] = -axis_x[:2]
    axis_y[:2] = -axis_y[:2]

    # With a single slice the through-plane direction is the in-plane normal.
    axis_z = np.cross(axis_x, axis_y)

    # Determine the z spacing
    if hasattr(d, 'SpacingBetweenSlices'):
        dz = float(d.SpacingBetweenSlices)
    else:
        print('Warning: can not find attribute SpacingBetweenSlices. Use attribute SliceThickness instead.')
        dz = float(d.SliceThickness)

    # Affine matrix which converts the voxel coordinate to world coordinate
    affine = np.eye(4)
    affine[:3,0] = axis_x * dx
    affine[:3,1] = axis_y * dy
    affine[:3,2] = axis_z * dz
    affine[:3,3] = pos_ul

    return affine

In [None]:
# Column names for the scalar results stored per participant; quantify_fat pairs
# them, in order, with res[2:] returned by predict_stochastic.
# FIXME: keep this list in sync with the outputs of network_utils.predict_stochastic.
# Previous, longer version kept for reference:
# RESNAMES = ['consensus','uncertainty','meanArea (mm2)','stdArea (mm2)','mpDsc','gDsc','mpIou','gIou']
RESNAMES = ['meanArea (cm2)','stdArea (cm2)','predicted DSC']

def get_feid(zipfilePath):
    """Return the first seven characters of the zip file's base name.

    Presumably this is the participant identifier (f.eid) that prefixes each
    archive name -- confirm against the naming of the zip files.
    """
    _, fileName = os.path.split(zipfilePath)
    return fileName[:7]

def quantify_fat(zipfilePath):
    """Segment pericardial fat for one participant and write the results.

    The first 4-chamber frame is extracted from the archive, passed through
    predict_stochastic, and the resulting segmentation is resampled and padded
    back to the DICOM's own resolution and dimensions. The image and the
    segmentation are written as NIfTI files into a per-participant folder under
    OUTPUT_DIRECTORY.

    Parameters
    ----------
    zipfilePath : str
        Path to the participant's long-axis zip archive.

    Returns
    -------
    dict
        Always contains 'f.eid'. When an image was found it also contains
        'date' and one entry per name in RESNAMES (areas converted to cm2).

    Raises
    ------
    ValueError
        If the restored segmentation does not match the DICOM image dimensions.
    """
    feid = get_feid(zipfilePath)
    #create dictionary for returning results.
    resultDict = {'f.eid':feid}

    #extract the pixels for the image.
    im,imagedDate,dicom_object = extract_first_4Ch_image(zipfilePath)

    if im is None:
        # nothing usable in this archive: report only the participant id
        return resultDict

    resultDict['date'] = imagedDate
    # add the leading and trailing singleton axes expected by predict_stochastic
    im = im.reshape((1,*im.shape,1))
    res = predict_stochastic(MODEL,N,ACCURACYMODEL,im) #FIXME remove the unnecessary metrics

    #get the actual predicted segmentation (boolean version)
    segmentation = res[0].squeeze()

    #resize to original resolution and dimensions
    originalPixels = dicom_object.pixel_array
    segmentation = resample_image(segmentation,PXSPACING,np.array(dicom_object.PixelSpacing))
    segmentation = pad_voxels(segmentation,originalPixels.squeeze().shape)

    # Explicit check rather than assert, so it still runs under `python -O`.
    if any(s != o for s,o in zip(segmentation.shape,originalPixels.shape)):
        raise ValueError("you've broken something in the image dimensions")

    #output directory for segmentation
    output_directory = os.path.join(OUTPUT_DIRECTORY,str(feid))
    os.makedirs(output_directory, exist_ok=True)

    #locations of the output nifti files...
    img_loc = os.path.join(output_directory,'la_4ch_ED.nii.gz')
    seg_loc = os.path.join(output_directory,'pcf_seg_la_4ch_ED.nii.gz')

    # NOTE(review): `dn` is not imported anywhere in the visible part of this
    # file (nibabel is imported as `nib` but unused). Confirm which module
    # provides dicom_to_volume / write_nifti and import it.
    vol, _, _ = dn.dicom_to_volume([dicom_object])
    # reshape() itself fails if the segmentation cannot take the volume's shape
    segmentation = segmentation.reshape(*vol.shape).astype('int')
    # the affine is derived from the DICOM header rather than taken from dn
    affine = get_affine(dicom_object)

    #write niftis for the image and for its segmentation
    dn.write_nifti(img_loc, vol, affine)
    dn.write_nifti(seg_loc, segmentation, affine)

    #wrap quantitative results up into a dict for easy DataFrame-ing
    resultDict.update(dict(zip(RESNAMES,res[2:])))

    #ensure that units of area are correct...
    resultDict['meanArea (cm2)'] *= (PXAREA/100)
    resultDict['stdArea (cm2)'] *= (PXAREA/100)
    return resultDict
    
    

In [None]:
# Run the prediction over every archive, resuming from an existing results
# file. Row i of the results table corresponds to allZips[i], so everything
# below the current row count has already been processed.

def _with_new_rows(existing, newRows):
    """Return `existing` with the dicts in `newRows` appended as rows."""
    newFrame = pd.DataFrame(newRows)
    if existing.empty:
        return newFrame
    return pd.concat([existing, newFrame], ignore_index=True)

# The checkpoint file lives here; make sure it exists before the first write.
os.makedirs(OUTPUT_DIRECTORY, exist_ok=True)

if os.path.isfile(RESULTSFILE):
    results = pd.read_csv(RESULTSFILE)
else:
    #create a dataframe to store results
    results = pd.DataFrame()

alreadyProcessed = len(results)
# Rows are buffered and concatenated in batches: DataFrame.append was removed
# in pandas 2.0 and re-copied the whole table on every call.
pendingRows = []

for i,zipfilePath in enumerate(allZips):
    if i >= alreadyProcessed:
        pendingRows.append(quantify_fat(zipfilePath))
    # checkpoint every 100 archives so an interrupted run loses little work
    if i % 100 == 0 and pendingRows:
        results = _with_new_rows(results, pendingRows)
        pendingRows = []
        results.to_csv(RESULTSFILE,index=False)

if pendingRows:
    results = _with_new_rows(results, pendingRows)
    pendingRows = []

results.to_csv(RESULTSFILE,index=False)