In [1]:
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
import cv2
from myshow import myshow
import sys
from scipy import linalg
from scipy.io import loadmat, savemat
import cardiacDicomGlobals as cdg
%matplotlib inline


In [2]:
# Reconstruction options
doGlobalFreqSearch = False   # NOTE(review): not read by any of the cells shown here
multiCoilProcessing = False  # True: data carry a coil axis; channels in channelList are combined
channelList = [5, 6, 7]      # coil indices combined when multiCoilProcessing is True

# Per-coil observations (coil index: what it shows well)
#   0: good bicarb
#   2: good pyruvate ex
#   3: good pyruvate ex !!
#   7: good bicarb

In [3]:
def closestIndex(f, faxis):
    """Return the index of the entry of `faxis` nearest to `f` (first one on ties)."""
    distances = np.abs(faxis - f)
    return np.argmin(distances)
def floatToInt2(img):
    """Rescale an image linearly so its maximum maps to 255 and cast to uint8.

    Despite the name, the result dtype is ``np.uint8``; the cast truncates
    toward zero. An image whose maximum is 0 (e.g. all zeros) would otherwise
    divide by zero and cast NaN to uint8, so it is returned as an all-zero
    uint8 array of the same shape.

    NOTE(review): negative inputs are not clipped and do not map cleanly to uint8.
    """
    peak = img.max()
    if peak == 0:
        return np.zeros(np.shape(img), dtype=np.uint8)
    return np.uint8(255 * (img / peak))

def imageGrad(img, kernelSize):
    """Sobel gradient magnitude of `img`.

    First-order x and y Sobel derivatives are computed in float64
    (cv2.CV_64F) with aperture `kernelSize`; the result is
    sqrt(gx*gx + gy*gy) elementwise.
    """
    gx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=kernelSize)
    gy = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=kernelSize)
    return np.sqrt(gx * gx + gy * gy)

def imageThresh(img):
    """Fixed global binary threshold: pixels > 80 become 255, all others 0.

    Uses cv2.THRESH_BINARY; in this notebook it is applied to the uint8
    output of floatToInt2.
    """
    _threshold, binary = cv2.threshold(img, 80, 255, cv2.THRESH_BINARY)
    return binary

def imageAnd(im1, im2):
    """Per-pixel bitwise AND of two equally sized images (cv2.bitwise_and)."""
    combined = cv2.bitwise_and(im1, im2)
    return combined

def meanUpperQuantileImage(img, mask, q):
    """Mean of the brightest pixels of `img` inside `mask`.

    Pixels are selected where `mask` equals its own maximum value (so a 0/255
    mask selects the 255 region; an all-zero mask selects every pixel). Of the
    selected pixels, those strictly above their `q`-quantile are averaged.

    If no selected pixel lies strictly above the quantile (a flat region, where
    the quantile equals the maximum), the quantile value itself is returned.
    Otherwise the result would be the NaN mean of an empty array, and a later
    np.argmax over these scores would select the NaN entry.

    Parameters
    ----------
    img : array
        Image values, same shape as `mask`.
    mask : array
        Selection mask; pixels equal to ``mask.max()`` are used.
    q : float
        Quantile in [0, 1], e.g. 0.75 for the upper quartile.
    """
    selected = img[mask == mask.max()].flatten()
    threshold = np.quantile(selected, q)
    upper = selected[selected > threshold]
    if upper.size == 0:
        # every selected value <= threshold <= max  =>  threshold == max
        return threshold
    return np.mean(upper)

def sumImaginaryComponent(img, mask):
    """Sum of |Im(img)| over the pixels where `mask` > 0."""
    return np.abs(np.imag(img))[mask > 0].sum()

# Edge-sharpness score used to choose the global frequency shift (works ok).
def gradientObjectiveFunction(img, kernelSize, mask):
    """Mean of the Sobel-gradient magnitudes of `img` above their 90th percentile
    within `mask` (see meanUpperQuantileImage). Larger means sharper edges."""
    return meanUpperQuantileImage(imageGrad(img, kernelSize), mask, .90)

# Experimental alternative score; doesn't work too well.
def gradientObjectiveFunction2(img, kernelSize, mask):
    """Mean of `img` over pixels that are both bright and inside `mask`.

    `img` is rescaled to uint8 (floatToInt2) and thresholded at 80
    (imageThresh); that binary image is ANDed with `mask`, and the original
    `img` values at the surviving pixels are averaged.

    NOTE(review): despite its name this score does not depend on the image
    gradient; `kernelSize` is kept only so the signature matches
    gradientObjectiveFunction.
    """
    brightMask = imageThresh(floatToInt2(img))
    jointMask = imageAnd(brightMask, mask)
    return np.mean(img[jointMask > 0])
    
def calcSNR(img, noise, noiseSize=20):
    """Express `img` in units of the noise standard deviation.

    The noise level is the standard deviation of the top-left
    `noiseSize` x `noiseSize` patch of `noise` (assumed signal-free —
    TODO confirm per dataset). `img` is divided elementwise by it.

    Parameters
    ----------
    img : array
        Image to scale.
    noise : array with at least two dimensions
        Image from which the corner noise patch is taken.
    noiseSize : int, optional
        Side length of the corner patch; the default of 20 is the
        original behavior.
    """
    noiseEst = np.std(noise[:noiseSize, :noiseSize])
    return img / noiseEst

In [4]:
# Frequency grid (Hz) of the multi-frequency reconstructions: fmin..fmax in steps of df.
fmin, fmax, df = -45, 45, 5
nf = int((fmax - fmin) / df) + 1   # number of grid points
freqAxis = np.linspace(fmin, fmax, nf)

In [5]:
# Multi-frequency B0 correction, one patient per iteration. For each patient:
#   1. build the multi-frequency interpolation (MFI) coefficient matrix cjmat,
#   2. resample the median-filtered B0 phase map, the dilated mask and the
#      localizer onto the 13C grid; threshold the localizer into locMask,
#   3. load the reconstructions made at every frequency in freqAxis,
#   4. build segmented (mf) and MFI (mfi) corrected images for each global
#      shift in globalFreqSearch, score them by edge sharpness, pick the best,
#   5. save images and scores to B0CorrectedImages/<mfr>.mat.
J = 1j
PI = np.pi


def _freqFileName(matFileBase, k):
    """Path of the intermediate reconstruction for freqAxis[k] (files are numbered from 1)."""
    return 'matFilesIntermediate/' + matFileBase + '_f' + str(k + 1) + '.mat'


def _rmsCombine(coilData, axis=-1):
    """Coil combination: sqrt of the mean of |coilData|**2 along `axis`."""
    return np.sqrt(np.mean(np.abs(coilData) ** 2, axis=axis))


pind = 0
for currentPatient in cdg.patientList:

    print('patient '+str(pind))

    # ---- MFI coefficients ---------------------------------------------------
    timeAxis = loadmat(currentPatient.readout)['t']
    A = np.exp(-2*J*PI*np.outer(freqAxis, timeAxis))   # (nf, nsamples)
    # linalg.pinv2 was removed in SciPy 1.9; linalg.pinv is SVD-based since 1.7.
    Ainv = np.transpose(linalg.pinv(A))                 # (nf, nsamples)
    # Row j: least-squares weights c with sum_k c[k] * A[k, :] ~= A[j, :], i.e.
    # the weights that synthesize frequency freqAxis[j] from the whole grid.
    # NOTE(review): every target is itself a row of A, so cjmat == A @ pinv(A);
    # if A has full row rank that is ~identity and MFI reduces to the segmented
    # correction. Confirm this is intended (MFI usually evaluates off-grid).
    cjmat = np.transpose(np.matmul(Ainv, np.transpose(A)))

    # ---- images, masks and B0 map on the 13C grid ---------------------------
    b0mag = sitk.ReadImage(currentPatient.b0mag)
    b0phase = sitk.ReadImage(currentPatient.b0phase)
    loc = sitk.ReadImage(currentPatient.loc)
    c13 = sitk.ReadImage(currentPatient.c13)
    mask = cv2.imread(currentPatient.mask, 0)   # flag 0: load as grayscale

    # dilate the mask once with a 5x5 kernel and wrap it as a sitk image
    kernel = np.ones((5, 5), np.uint8)
    maskDilation = cv2.dilate(mask, kernel, iterations=1)
    maskSITK = sitk.GetImageFromArray(np.expand_dims(maskDilation, axis=0))
    maskSITK.CopyInformation(loc)

    # 3x3 median filter of the B0 phase map, wrapped back into a sitk image
    b0phasePixels = np.squeeze(sitk.GetArrayFromImage(b0phase))
    medianFiltered = cv2.medianBlur(b0phasePixels, 3)
    b0phasef = sitk.GetImageFromArray(np.expand_dims(medianFiltered, axis=0))
    b0phasef.CopyInformation(b0phase)

    resample = sitk.ResampleImageFilter()
    resample.SetReferenceImage(c13)
    # NOTE(review): B-spline is also applied to the binary mask, which can blur or
    # ring its edges; nearest-neighbour may be more appropriate for mask_r.
    resample.SetInterpolator(sitk.sitkBSpline)
    resample.AddCommand(sitk.sitkProgressEvent, lambda: print("\rProgress: {0:03.1f}%...".format(100*resample.GetProgress()),end=''))
    resample.AddCommand(sitk.sitkProgressEvent, lambda: sys.stdout.flush())
    b0phase_r = resample.Execute(b0phasef)
    b0mag_r = resample.Execute(b0mag)
    mask_r = resample.Execute(maskSITK)
    loc_r = resample.Execute(loc)

    # localizer body mask: |loc_r| scaled to 0..255, threshold at 20, dilate twice (3x3)
    threshInput = np.abs(np.squeeze(sitk.GetArrayFromImage(loc_r)))  # pixels go slightly negative somehow
    threshInput = np.uint8(255.0 * threshInput / threshInput.max())
    ret, thresh1 = cv2.threshold(threshInput, 20, 255, cv2.THRESH_BINARY)
    kernel = np.ones((3, 3), np.uint8)
    locMask = cv2.dilate(thresh1, kernel, iterations=2)

    # ---- per-frequency reconstructions --------------------------------------
    matFileBase = currentPatient.mfr
    array = np.squeeze(np.cdouble(loadmat(_freqFileName(matFileBase, 0))['bb']))
    arrayShape = array.shape
    nx, ny, ntime, nmet = arrayShape[0], arrayShape[1], arrayShape[2], arrayShape[3]

    # mfr:  image per frequency (complex for single-coil data; RMS coil-combined
    #       magnitude of the channels in channelList when multiCoilProcessing is on)
    # mfrc: complex data per frequency (with a coil axis when multi-coil)
    mfr = np.zeros((nx, ny, ntime, nmet, nf), dtype=np.cdouble)
    if multiCoilProcessing:
        ncoils = arrayShape[4]
        mfrc = np.zeros((nx, ny, ntime, nmet, ncoils, nf), dtype=np.cdouble)
    else:
        mfrc = np.zeros((nx, ny, ntime, nmet, nf), dtype=np.cdouble)

    print("image array shape")
    print(mfr.shape)

    for f in range(nf):
        pixelArrayComplex = np.squeeze(np.cdouble(loadmat(_freqFileName(matFileBase, f))['bb']))
        mfrc[..., f] = pixelArrayComplex
        if multiCoilProcessing:
            # combine magnitudes as sqrt(mean |x|^2) over the selected channels
            mfr[..., f] = _rmsCombine(pixelArrayComplex[..., channelList])
        else:
            mfr[..., f] = pixelArrayComplex

    # ---- frequency-corrected images for every global shift ------------------
    gfmin = -15
    gfmax = 15
    ngf = int((gfmax - gfmin)/df + 1)
    globalFreqSearch = np.linspace(gfmin, gfmax, ngf)   # Hz
    globalOnResInd = closestIndex(0, globalFreqSearch)
    onResInd = closestIndex(0, freqAxis)

    # B0 map scaled by 1.070/4.257 (presumably the 13C/1H gyromagnetic-ratio
    # factor — TODO confirm the units of the b0phase series)
    b0map = np.squeeze(sitk.GetArrayFromImage(b0phase_r))
    b0map = b0map * 1.070 / 4.257
    maskPixels = np.squeeze(sitk.GetArrayFromImage(mask_r))

    # Grid index per pixel and shift: nearest freqAxis point to (local B0 + global
    # shift) inside the localizer mask; the on-resonance reconstruction outside.
    # NOTE(review): assumes mfr's first two axes line up with the sitk arrays
    # b0map / locMask — confirm orientation.
    targetFreq = b0map[:nx, :ny, None] + globalFreqSearch                  # (nx, ny, ngf)
    indF = np.argmin(np.abs(freqAxis - targetFreq[..., None]), axis=-1)    # (nx, ny, ngf)
    inBody = (locMask[:nx, :ny] != 0)[:, :, None]                          # (nx, ny, 1)
    freqIndex = np.where(inBody, indF, onResInd)

    # segmented: each pixel taken from the reconstruction at its own frequency
    mf = np.take_along_axis(mfr, freqIndex[:, :, None, None, :], axis=-1)

    # MFI: per pixel/shift weighted sum over all reconstructions; phase sensitive,
    # so it is applied per coil before coil combination
    coeffs = cjmat[indF]                                                     # (nx, ny, ngf, nf)
    if multiCoilProcessing:
        sumSq = np.zeros((nx, ny, ntime, nmet, ngf))
        for coil in channelList:
            weighted = np.einsum('xytmk,xygk->xytmg', mfrc[:, :, :, :, coil, :], coeffs)
            sumSq += np.abs(weighted) ** 2
        mfiBody = np.sqrt(sumSq / len(channelList))
    else:
        mfiBody = np.einsum('xytmk,xygk->xytmg', mfrc, coeffs)
    onRes = mfr[:, :, :, :, onResInd, None]
    mfi = np.where(inBody[:, :, None, None, :], mfiBody, onRes)

    # ---- time frame of peak signal per metabolite ---------------------------
    kernelSize = 3
    numSkipEachMet = [2, 0, 1]  # skip the first frames per metabolite to avoid pyruvate artifacts (pixel overrange, etc.)
    peakTimeSignal = np.zeros((3, ntime))
    for it in range(ntime):
        for im in range(3):
            if it >= numSkipEachMet[im]:
                img = np.abs(mfi[:, :, it, im, globalOnResInd])
                peakTimeSignal[im, it] = meanUpperQuantileImage(img, maskPixels, .75)
    peakTime = np.argmax(peakTimeSignal, axis=1)

    # ---- score every global shift -------------------------------------------
    noiseTimeIndex = 7  # NOTE(review): hard-coded noise frame; the runs shown have ntime == 8 (last frame)
    peakSignalUncorr = np.zeros((3, ngf))
    peakSignalSeg = np.zeros((3, ngf))
    peakSignalMFI = np.zeros((3, ngf))
    maxSigUncorr = np.zeros((3, ngf))
    maxSigSeg = np.zeros((3, ngf))
    maxSigMFI = np.zeros((3, ngf))

    for gf in range(ngf):
        indf = closestIndex(globalFreqSearch[gf], freqAxis)
        for im in range(3):
            pt = int(peakTime[im])
            noise = mf[:, :, noiseTimeIndex, im, gf]
            im_unc = calcSNR(np.abs(mfr[:, :, pt, im, indf]), noise)
            im_mfi = calcSNR(np.abs(mfi[:, :, pt, im, gf]), noise)
            im_seg = calcSNR(np.abs(mf[:, :, pt, im, gf]), noise)

            maxSigUncorr[im, gf] = meanUpperQuantileImage(im_unc, maskPixels, .75)
            maxSigSeg[im, gf] = meanUpperQuantileImage(im_seg, maskPixels, .75)
            maxSigMFI[im, gf] = meanUpperQuantileImage(im_mfi, maskPixels, .75)

            # gradient-sharpness scores
            peakSignalUncorr[im, gf] = gradientObjectiveFunction(im_unc, kernelSize, maskPixels)
            peakSignalMFI[im, gf] = gradientObjectiveFunction(im_mfi, kernelSize, maskPixels)
            peakSignalSeg[im, gf] = gradientObjectiveFunction(im_seg, kernelSize, maskPixels)

    im = 2  # metabolite used to choose the global shift
    bestGlobalShiftUncorr = np.argmax(peakSignalUncorr[im, :])
    bestGlobalShift = np.argmax(peakSignalSeg[im, :])

    print('best global shift is at '+str(globalFreqSearch[bestGlobalShift])+' Hz')
    print('best global shift-only is at '+str(globalFreqSearch[bestGlobalShiftUncorr])+' Hz')

    # no correction
    img_unc = np.abs(np.squeeze(mfr[:, :, :, :, onResInd]))
    # global shift only, at the shift that scores best without per-pixel correction
    img_gs = np.abs(np.squeeze(mfr[:, :, :, :, closestIndex(globalFreqSearch[bestGlobalShiftUncorr], freqAxis)]))
    # global shift + frequency-segmented correction
    img_seg = np.abs(np.squeeze(mf[:, :, :, :, bestGlobalShift]))
    # global shift + MFI
    img_mfi = np.abs(np.squeeze(mfi[:, :, :, :, bestGlobalShift]))

    # DICOM tag (0018,0084) ImagingFrequency, in MHz
    h1freq = float(b0mag.GetMetaData('0018|0084'))
    c13FreqAdjustment = globalFreqSearch[bestGlobalShift] * 1e-6   # Hz -> MHz
    c13freq = float(c13.GetMetaData('0018|0084'))

    print('1H freq (dicom header) = '+ str(h1freq))
    print('13C freq (dicom header) = '+ str(c13freq))
    print('13C freq adjustment (from reconstruction) = '+ str(c13FreqAdjustment))

    ratio = (c13freq+c13FreqAdjustment)/h1freq
    print('ratio = '+str(ratio))

    fileDict = {
        'img_unc': img_unc,
        'img_gs': img_gs,
        'img_seg': img_seg,
        'img_mfi': img_mfi,
        'h1freq': h1freq,
        'c13freq': c13freq,
        'globalFreqSearch': globalFreqSearch,
        'peakSignalUncorr': peakSignalUncorr,
        'peakSignalSeg': peakSignalSeg,
        'peakSignalMFI': peakSignalMFI,
        'freqAxis': freqAxis,
        'maxSigUncorr': maxSigUncorr,
        'maxSigSeg': maxSigSeg,
        'maxSigMFI': maxSigMFI,
    }

    outFileName = "B0CorrectedImages/"+currentPatient.mfr+".mat"
    savemat(outFileName, fileDict)

    pind += 1

patient 0
Progress: 100.0%...image array shape
(128, 128, 8, 3, 19)
best global shift is at 0.0 Hz
best global shift-only is at 15.0 Hz
1H freq (dicom header) = 127.736282
13C freq (dicom header) = 32.12431
13C freq adjustment (from reconstruction) = 0.0
ratio = 0.2514893145238093
patient 1
Progress: 100.0%...image array shape
(128, 128, 8, 3, 19)
best global shift is at 5.0 Hz
best global shift-only is at 10.0 Hz
1H freq (dicom header) = 127.736282
13C freq (dicom header) = 32.12431
13C freq adjustment (from reconstruction) = 4.9999999999999996e-06
ratio = 0.2514893536669558
patient 2
Progress: 100.0%...image array shape
(128, 128, 8, 3, 19)
best global shift is at 5.0 Hz
best global shift-only is at 10.0 Hz
1H freq (dicom header) = 127.736239
13C freq (dicom header) = 32.124299
13C freq adjustment (from reconstruction) = 4.9999999999999996e-06
ratio = 0.2514893522111607
patient 3
Progress: 100.0%...image array shape
(128, 128, 8, 3, 19)
best global shift is at 5.0 Hz
best global shif