# Duke Recon Function

## Importing Libraries & Supporting Scripts

In [1]:
import pdb
import numpy as np
from absl import app, logging
import time
from recon import dcf, kernel, proximity, recon_model, system_model
from utils import io_utils
import matplotlib.pyplot as plt

import skimage.util

from typing import Any, Optional, Tuple
import math
import spect.nmr_timefit as fit
import pickle

from scipy.io import loadmat
import tkinter as tk
from tkinter import filedialog

import os
parent_dir = os.getcwd()  # start-up working directory (re-derived when the data is loaded)
N = 128  # grid size used throughout this notebook: voxels per side of the volumes
I = 1j  # shorthand for the imaginary unit

# NOTE(review): a stray, indented fragment `np.bool8: (False, True),` used to sit
# here and made the file fail with IndentationError. It appears to be an entry
# copied from scikit-image's dtype-range table; if `import skimage` fails under
# NumPy 2.x (where the `np.bool8` alias no longer exists), upgrade scikit-image
# rather than patching it by hand.


## Defining Main Functions

In [2]:
# Duke's main recon function
def reconstruct(
    data: np.ndarray,
    traj: np.ndarray,
    kernel_sharpness: float = 0.32,
    kernel_extent: float = 0.32 * 9,
    overgrid_factor: int = 3,
    image_size: int = 128,
    n_dcf_iter: int = 20,
    verbosity: bool = True,
) -> np.ndarray:
    """Reconstruct a 3D image volume from k-space data and its trajectory.

    Builds a Gaussian kernel (`kernel.Gaussian`) wrapped in an L2 proximity
    object, a `system_model.MatrixSystemModel`, an iterative density
    compensation object (`dcf.IterativeDCF`) and a `recon_model.LSQgridded`
    reconstruction, runs the reconstruction, and logs the execution time.

    Args:
        data (np.ndarray): k-space data of shape (K, 1).
        traj (np.ndarray): k-space trajectory of shape (K, 3).
        kernel_sharpness (float): passed as the Gaussian kernel sigma; a larger
            kernel sharpness gives a sharper image.
        kernel_extent (float): kernel extent (the default is 9x the default
            sharpness).
        overgrid_factor (int): overgridding factor.
        image_size (int): target reconstructed image size; the output grid is
            (image_size, image_size, image_size).
        n_dcf_iter (int): number of iterative DCF iterations.
        verbosity (bool): forwarded to every recon component to log their
            output messages.

    Returns:
        np.ndarray: reconstructed image volume.
    """
    start_time = time.time()
    prox_obj = proximity.L2Proximity(
        kernel_obj=kernel.Gaussian(
            kernel_extent=kernel_extent,
            kernel_sigma=kernel_sharpness,
            verbosity=verbosity,
        ),
        verbosity=verbosity,
    )
    system_obj = system_model.MatrixSystemModel(
        proximity_obj=prox_obj,
        overgrid_factor=overgrid_factor,
        image_size=np.array([image_size, image_size, image_size]),
        traj=traj,
        verbosity=verbosity,
    )
    dcf_obj = dcf.IterativeDCF(
        system_obj=system_obj, dcf_iterations=n_dcf_iter, verbosity=verbosity
    )
    recon_obj = recon_model.LSQgridded(
        system_obj=system_obj, dcf_obj=dcf_obj, verbosity=verbosity
    )
    image = recon_obj.reconstruct(data=data, traj=traj)
    # Release references to the intermediate recon objects before returning.
    del recon_obj, dcf_obj, system_obj, prox_obj
    execution_time = time.time() - start_time
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logging.info("Execution time: %.2f seconds", execution_time)
    return image

In [3]:
def DisplayImageSlice(image, slice_id=None, image_type='modulus'):
    """Show one 2D slice ``image[:, :, slice_id]`` of a 3D image with a colorbar.

    Args:
        image: 3D array (typically complex), sliced along its third axis.
        slice_id: Slice index along the third axis. ``None`` (the default)
            selects the centre slice of the actual volume,
            ``image.shape[2] // 2``; for the 128-voxel volumes in this
            notebook that is slice 64.
        image_type: Component to display, matched case-insensitively:
            'real'; 'imag' or 'imaginary'; anything else shows the modulus.
    """
    if slice_id is None:
        # Use the real centre of this volume instead of a module-level size.
        slice_id = image.shape[2] // 2
    component = str(image_type).lower()
    plane = image[:, :, slice_id]

    if component == 'real':
        plt.imshow(np.real(plane))
        plt.title('Real Component of Image (2D Slice)')
    elif component in ('imag', 'imaginary'):
        plt.imshow(np.imag(plane))
        plt.title('Imaginary Component of Image (2D Slice)')
    else:  # modulus (default)
        plt.imshow(np.abs(plane))
        plt.title('Modulus Component of Image (2D Slice)')

    plt.colorbar()
    plt.show()

In [4]:
def makeSlide(A):
    '''Show a 3D array as a single 2D montage of its slice magnitudes.

    Every slice A[:, :, k] along the third axis is converted to its absolute
    value and tiled into one mosaic with 1 pixel of zero padding between
    tiles; the figure is shown immediately.
    '''
    tiles = [abs(A[:, :, idx]) for idx in range(A.shape[2])]
    mosaic = skimage.util.montage(tiles, padding_width=1, fill=0)
    plt.imshow(mosaic)
    plt.show()

## Loading Data

In [5]:
# Load the k-space trajectory. When the notebook was started inside a
# "*-main" folder, first step up two directory levels so the relative data
# paths below resolve.
if parent_dir.endswith('-main'):
    os.chdir(os.path.join('..', '..'))
parent_dir = os.getcwd()

traj = loadmat('traj/traj_gas_afia.mat')

# The x, y and z k-space coordinates are stored along the last axis of
# traj['data']; flatten each one and stack them into a (K, 3) sample list.
# The stored trajectory is already scaled, so no multiplication by N is
# applied here.
kx, ky, kz = (traj['data'][:, :, axis].ravel() for axis in range(3))
trajlist = np.column_stack((kx, ky, kz))

In [6]:
# Phasors produced by the data synthesizer; reshaped into a single column of
# shape (K, 1), the data layout `reconstruct` documents.
phasors_128 = np.load(
    os.path.join(
        parent_dir, 'DataSynthesizer/3D_binary_multisphere_radtraj_phasors_128.npy'
    )
)
phasors_128 = np.reshape(phasors_128, (phasors_128.shape[0], 1))

## Gas Reconstruction

In [7]:
# Two gas reconstructions from the same phasors and trajectory that differ
# only in kernel sharpness (the kernel extent is 9x the sharpness each time):
#   0.14 -> smoother image, kept as the high-SNR reconstruction
#   0.32 -> sharper image, kept as the high-resolution reconstruction
image_gas_highsnr = reconstruct(
    data=phasors_128,
    traj=trajlist,
    kernel_sharpness=0.14,
    kernel_extent=9 * 0.14,
    image_size=128,
)

image_gas_highreso = reconstruct(
    data=phasors_128,
    traj=trajlist,
    kernel_sharpness=0.32,
    kernel_extent=9 * 0.32,
    image_size=128,
)

In [8]:
# Reorient for display (presumably back to the synthesized image's
# orientation): move the reconstruction's first axis to the end with
# transpose (1, 2, 0), reverse every axis, and keep only the magnitude.
# [::-1] reverses a whole axis of any length, not just a 128-voxel one.
image_gas_highreso = np.abs(np.transpose(image_gas_highreso, (1, 2, 0)))[::-1, ::-1, ::-1]
#makeSlide(image_gas_highreso)

## Dissolved Phase Reconstruction

In [9]:
# NOTE(review): the dissolved image is reconstructed from `phasors_128` with
# the same kernel settings as `image_gas_highsnr`, presumably as a stand-in
# for real dissolved-phase k-space data; confirm before using real scans.
image_dissolved = reconstruct(
    data=phasors_128,
    traj=trajlist,
    kernel_sharpness=0.14,
    kernel_extent=9 * 0.14,
    image_size=128,
)

In [10]:
# Reorient the complex images used by the Dixon decomposition into the same
# display orientation as the high-resolution gas image: move the first axis
# to the end with transpose (1, 2, 0) and reverse every axis.
# The dissolved image must stay COMPLEX. The RBC/membrane separation works on
# its phase (np.angle of the in-mask sum), and np.abs here would force that
# phase to zero.
image_dissolved = np.transpose(image_dissolved, (1, 2, 0))[::-1, ::-1, ::-1]
# The high-SNR gas image gets the identical transform. The mask and the B0
# phase map are derived from it, so it must line up voxel-for-voxel with the
# dissolved image they are applied to.
image_gas_highsnr = np.transpose(image_gas_highsnr, (1, 2, 0))[::-1, ::-1, ::-1]
#makeSlide(image_dissolved)

## Dixon Reconstruction

In [11]:
def round_up(x: float, decimals: int = 0) -> float:
    """Round ``x`` up (towards +inf) to ``decimals`` decimal places.

    Values that are already on the ``10**-decimals`` grid up to floating-point
    noise are kept as they are. Without this, e.g. 0.07 * 100 evaluates to
    7.000000000000001 and a plain ceil would return 0.08 instead of 0.07.

    >>> round_up(1.231, 2)
    1.24
    >>> round_up(0.07, 2)
    0.07
    """
    factor = 10**decimals
    scaled = x * factor
    nearest = round(scaled)
    # Treat a scaled value within float noise of an integer as that integer.
    if math.isclose(scaled, nearest, rel_tol=1e-9, abs_tol=1e-9):
        return nearest / factor
    return math.ceil(scaled) / factor

In [12]:
# calculating rbc/membrane ratio
def calculate_static_spectroscopy(
    fid: np.ndarray,
    dwell_time: float = 1.95e-05,
    tr: float = 0.015,
    center_freq: float = 34.09,
    rf_excitation: int = 218,
    n_avg: Optional[int] = None,
    n_avg_seconds: int = 1,
    method: str = "voigt",
    plot: bool = False,
) -> Tuple[float, Any]:
    """Fit averaged static spectroscopy data to a 3-peak model and extract RBC:M.

    A window of FIDs starting at the frame acquired at t = 2 s is averaged and
    fitted with `fit.NMR_TimeFit`. The RBC:M ratio is the fitted area of peak 0
    (RBC) divided by the fitted area of peak 1 (membrane).

    Args:
        fid (np.ndarray): Dissolved phase FIDs in format (n_points, n_frames).
        dwell_time (float): Dwell time in seconds (spacing of samples in a FID).
        tr (float): TR in seconds (spacing between consecutive FIDs).
        center_freq (float): Center frequency in MHz; scales the peak offsets
            passed to the fit.
        rf_excitation (int, optional): Excitation frequency in ppm.
            NOTE(review): currently unused — see the note on the peak offsets.
        n_avg (int, optional): Number of FIDs to average. When not given (or
            falsy) it is derived from ``n_avg_seconds``.
        n_avg_seconds (int): Number of seconds of FIDs to average when
            ``n_avg`` is not given.
        method (str): Line-shape model name forwarded to `NMR_TimeFit`.
        plot (bool, optional): Plot the fit. Defaults to False.

    Returns:
        Tuple of RBC:M ratio and the fitted `NMR_TimeFit` object.
    """
    n_points, n_frames = np.shape(fid)[:2]
    t = np.arange(n_points) * dwell_time  # sample times within one FID
    t_tr = np.arange(n_frames) * tr  # acquisition time of each FID

    # Start averaging at the frame acquired at t = 2 s. Frame times are
    # rounded up to 2 decimals before matching. If several frames match, the
    # middle one is used; if none does, averaging starts at frame 0.
    start_time = 2
    rounded_t_tr = np.array([round_up(x, 2) for x in t_tr])
    # flatnonzero returns a flat index array, so each element converts cleanly
    # to int (int() of a 1-element array from argwhere is deprecated).
    matches = np.flatnonzero(rounded_t_tr == start_time)
    start_ind = int(matches[matches.size // 2]) if matches.size else 0

    # calculate number of FIDs to average
    if not n_avg:
        n_avg = int(n_avg_seconds / tr)

    # NOTE(review): this slice spans n_avg + 1 frames and is capped at
    # n_frames - 1, so the final frame is never included. It is kept unchanged
    # to preserve existing numerical results; confirm the intent.
    end_ind = min(n_frames - 1, start_ind + n_avg + 1)
    data_dis_avg = np.average(fid[:, start_ind:end_ind], axis=1)

    fit_obj = fit.NMR_TimeFit(
        ydata=data_dis_avg,
        tdata=t,
        area=np.array([1, 1, 1]),
        # Offsets of the three peaks (RBC, membrane, third peak) scaled by
        # center_freq. NOTE(review): the third offset is hard-coded to -218
        # while `rf_excitation` (default 218) is ignored; presumably it should
        # be -rf_excitation — confirm before changing.
        freq=np.array([0, -21.7, -218.0]) * center_freq,
        fwhmL=np.array([8.8, 5.0, 2.0]) * center_freq,
        fwhmG=np.array([0, 6.1, 0]) * center_freq,
        phase=np.array([0, 0, 0]),
        line_broadening=0,
        zeropad_size=np.size(t),
        method=method,
    )
    # All fit parameters are left unbounded: 15 entries, presumably one per
    # (area, freq, fwhmL, fwhmG, phase) x 3 peaks.
    n_params = 5 * 3
    bounds = (np.full(n_params, -np.inf), np.full(n_params, np.inf))
    fit_obj.fit_time_signal_residual(bounds=bounds)

    if plot:
        fit_obj.plot_time_spect_fit()
    rbc_m_ratio = fit_obj.area[0] / np.sum(fit_obj.area[1])

    return rbc_m_ratio, fit_obj

In [13]:
# Load the dynamic-spectroscopy dictionary (created from the flip-calibration
# data of a real scan). This also leaves the working directory set to
# DataSynthesizer/.
# NOTE(review): unpickling can execute arbitrary code; only load trusted,
# locally generated files here.
os.chdir(os.path.join(parent_dir, 'DataSynthesizer'))
with open('Xe0067Pre_dict_dyn.pkl', 'rb') as dyn_file:
    dict_dyn = pickle.loads(dyn_file.read())

In [14]:
# Fit the dissolved-phase FIDs from the loaded dictionary and print the
# resulting RBC:membrane ratio (the fit object itself is discarded).
static_spect_args = {
    "fid": dict_dyn["fids_dis"],
    "dwell_time": dict_dyn["dwell_time"],
    "tr": dict_dyn["tr"],
    "center_freq": dict_dyn["freq_center"],
    "rf_excitation": dict_dyn["freq_excitation"],
    "plot": False,
}
rbc_m_ratio, _ = calculate_static_spectroscopy(**static_spect_args)
print(rbc_m_ratio)

0.4430417883756298


In [15]:
def normalize_images(images):
    """Min-max normalize an image stack to [0, 1] using one global range.

    The minimum and maximum are taken over the whole array (all images
    together), so relative intensities between images are preserved.

    Args:
        images: Array-like whose first axis indexes images; the remaining
            axes are the height, width and depth of each image.

    Returns:
        float64 array with the same shape as ``images``. A constant input has
        no range to divide by and yields all zeros (instead of NaNs from 0/0).
    """
    images = np.asarray(images)
    minimum_value, maximum_value = images.min(), images.max()
    value_range = float(maximum_value - minimum_value)
    shifted = images.astype(float) - float(minimum_value)
    if value_range == 0:
        return np.zeros_like(shifted)
    return shifted / value_range

In [16]:
# Binary mask from a simple global threshold: voxels whose normalized
# high-SNR gas magnitude exceeds half the full range.
threshold_value = 0.5
imgGasHiSNRNorm = normalize_images(np.abs(image_gas_highsnr))
mask_hiSNR = np.abs(imgGasHiSNRNorm) > threshold_value

In [17]:
def correct_b0(
    image: np.ndarray, mask: np.ndarray, max_iterations: int = 100
) -> np.ndarray:
    """Remove the mean phase of ``image`` inside ``mask`` and return its phase map.

    The complex image is repeatedly rotated by minus the mean phase of the
    masked voxels until that mean is within 1e-7 rad of zero. A single
    rotation is not enough in general because np.angle wraps to (-pi, pi].

    Args:
        image: Complex image. It is not modified; rotations produce new arrays.
        mask: Boolean mask selecting the voxels whose mean phase is zeroed.
        max_iterations: Iteration cap. The cap is checked after each rotation,
            so at most ``max_iterations + 1`` rotations are applied.

    Returns:
        np.angle of the corrected image, in radians. If ``mask`` selects no
        voxels, the phase of ``image`` is returned unchanged rather than an
        all-NaN map (the mean of an empty selection is NaN).
    """
    if not np.any(mask):
        return np.angle(image)

    for _ in range(max(max_iterations, 0) + 1):
        meanphase = np.mean(np.angle(image)[mask])
        image = np.multiply(image, np.exp(-1j * meanphase))
        # `not (... > tol)` also stops on a NaN mean instead of looping on it.
        if not abs(meanphase) > 1e-7:
            break
    return np.angle(image)

In [18]:
# Phase map of the high-SNR gas image after zeroing its mean in-mask phase.
diffphase = correct_b0(image_gas_highsnr, mask_hiSNR)

# Rotate the dissolved image so its summed in-mask signal points at the angle
# atan(RBC:M) taken from the flip-calibration fit, then remove the residual
# gas phase voxel by voxel.
desired_angle = np.arctan2(rbc_m_ratio, 1.0)
current_angle = np.angle(np.sum(image_dissolved[mask_hiSNR > 0]))
delta_angle = desired_angle - current_angle
image_dixon = image_dissolved * np.exp(1j * delta_angle) * np.exp(1j * (-diffphase))

# Imaginary part -> RBC, real part -> membrane. Each is sign-flipped unless
# its mean over the mask is already > 0.
dixon_imag, dixon_real = np.imag(image_dixon), np.real(image_dixon)
image_rbc = dixon_imag if np.mean(dixon_imag[mask_hiSNR]) > 0 else -1 * dixon_imag
image_membrane = dixon_real if np.mean(dixon_real[mask_hiSNR]) > 0 else -1 * dixon_real

## Displaying Gas, Membrane, & RBC

In [33]:
def display_multiple_slices(*arrays, titles=('Gas', 'Membrane', 'RBC')):
    '''Display several 3D arrays side by side, one montage per subplot.

    Each array is tiled like `makeSlide`: the magnitude of every slice
    A[:, :, k] along the third axis, with 1 pixel of zero padding between
    tiles. All montages share one single-row figure, which is shown
    immediately.

    Args:
        *arrays: 3D arrays to display, left to right. Nothing is drawn when no
            array is given.
        titles: Subplot titles matched to ``arrays`` by position; arrays past
            the end of ``titles`` get no title. Defaults to the gas / membrane
            / RBC order used in this notebook.
    '''
    if not arrays:
        return

    num_arrays = len(arrays)
    # squeeze=False keeps `axes` a (1, num_arrays) array even for one input;
    # with the default squeeze a lone Axes object is not indexable.
    _, axes = plt.subplots(1, num_arrays, figsize=(4 * num_arrays, 4), squeeze=False)

    for i, (ax, A) in enumerate(zip(axes[0], arrays)):
        tiles = [abs(A[:, :, k]) for k in range(A.shape[2])]
        ax.imshow(skimage.util.montage(tiles, padding_width=1, fill=0))
        if i < len(titles):
            ax.set_title(titles[i])

    plt.tight_layout()
    plt.show()

In [34]:
# Side-by-side montages: high-resolution gas, membrane, RBC.
display_multiple_slices(
    image_gas_highreso,
    image_membrane,
    image_rbc,
)