In [8]:
import os
import warnings
import json
from typing import Tuple, Callable, Union
import nibabel
import numpy as np
import numba

from petpal.kinetic_modeling.fit_tac_with_rtms import (
                                                       get_rtm_method,
                                                       get_rtm_output_size,
                                                       get_rtm_kwargs)
from petpal.utils.time_activity_curve import TimeActivityCurveFromFile
from petpal.utils.image_io import safe_load_4dpet_nifti
from petpal.kinetic_modeling import graphical_analysis
from petpal.kinetic_modeling.reference_tissue_models import weight_tac_decay, weight_tac_simple
from petpal.input_function.blood_input import read_plasma_glucose_concentration
from petpal.utils.image_io import safe_load_tac, safe_copy_meta, validate_two_images_same_dimensions, ScanTimingInfo, load_metadata_for_nifti_with_same_filename, get_half_life_from_meta, safe_load_meta

In [None]:
def apply_rtm2_to_all_voxels(tac_times_in_minutes: np.ndarray,
                             tgt_image: np.ndarray,
                             ref_tac_vals: np.ndarray,
                             mask_img: np.ndarray,
                             method: str = 'srtm2',
                             **analysis_kwargs) -> np.ndarray:
    """
    Fit a reference tissue model (RTM) voxel-by-voxel inside a mask and gather the fitted
    parameters into a parametric image.

    Args:
        tac_times_in_minutes (np.ndarray): 1D array of frame times, in minutes, shared by the
            reference TAC and every voxel TAC.
        tgt_image (np.ndarray): 4D PET array with shape (x, y, z, time).
        ref_tac_vals (np.ndarray): 1D reference-region TAC, the same length as
            `tac_times_in_minutes`.
        mask_img (np.ndarray): 3D mask for `tgt_image`. Only voxels whose mask value is greater
            than 0.5 are fitted; every other voxel is left at zero. Masking avoids fitting
            non-brain voxels and so saves computation time.
        method (str): RTM name handed to `get_rtm_method` and `get_rtm_output_size`.
            Default 'srtm2'.
        **analysis_kwargs: Extra keyword arguments forwarded to the fit function. If a
            ``bounds`` key is present (whatever its value), the bounded variant of the method is
            requested from `get_rtm_method`.

    Returns:
        params_img (np.ndarray): Array of shape (x, y, z, n_params), where n_params is
        ``get_rtm_output_size(method)``. Each masked voxel holds the first element returned by
        the fit function.
    """
    # Only the *presence* of a 'bounds' kwarg selects the bounded fitter.
    wants_bounds = True if "bounds" in analysis_kwargs else None
    fit_func = get_rtm_method(method=method, bounds=wants_bounds)

    dims = tgt_image.shape
    spatial_shape = (dims[0], dims[1], dims[2])
    n_params = get_rtm_output_size(method=method)
    params_img = np.zeros(spatial_shape + (n_params,), float)

    # np.ndindex walks the voxels in the same row-major order as three nested range loops.
    for voxel in np.ndindex(*spatial_shape):
        # `not (... > 0.5)` (rather than `<= 0.5`) keeps NaN mask values excluded.
        if not mask_img[voxel] > 0.5:
            continue
        fit_output = fit_func(tac_times_in_minutes=tac_times_in_minutes,
                              ref_tac_vals=ref_tac_vals,
                              tgt_tac_vals=tgt_image[voxel],
                              **analysis_kwargs)
        params_img[voxel] = fit_output[0]

    return params_img

In [None]:
def apply_rtm2_to_all_voxels_scan_timing(tac_times_in_minutes: np.ndarray,
                                         tgt_image: np.ndarray,
                                         ref_tac_vals: np.ndarray,
                                         mask_img: np.ndarray,
                                         method: str = 'srtm2',
                                         **analysis_kwargs) -> np.ndarray:
    """
    Generates parametric images for 4D-PET data using a reference tissue method.

    This variant currently has exactly the same behavior as :func:`apply_rtm2_to_all_voxels`,
    so it delegates to it instead of carrying a second copy of the voxel loop.

    Args:
        tac_times_in_minutes (np.ndarray): A 1D array representing the reference TAC and PET frame
            times in minutes.
        tgt_image (np.ndarray): A 4D array representing the 3D PET image over time.
            The shape of this array should be (x, y, z, time).
        ref_tac_vals (np.ndarray): A 1D array representing the reference TAC values. This array
            should be of the same length as `tac_times_in_minutes`.
        mask_img (np.ndarray): A 3D array representing the brain mask for `tgt_image`. Voxels with
            a mask value greater than 0.5 are fitted; the rest are left at zero.
        method (str): RTM method name. Default 'srtm2'.
        **analysis_kwargs: Extra keyword arguments forwarded to the RTM fit function.

    Returns:
        params_img (np.ndarray): A 4D array with RTM parameter fit results based on the supplied
            method.
    """
    return apply_rtm2_to_all_voxels(tac_times_in_minutes=tac_times_in_minutes,
                                    tgt_image=tgt_image,
                                    ref_tac_vals=ref_tac_vals,
                                    mask_img=mask_img,
                                    method=method,
                                    **analysis_kwargs)

In [2]:
def apply_rtm2_to_all_voxels_weight_simple(scan_timing: 'ScanTimingInfo',
                                           tgt_image: np.ndarray,
                                           ref_tac_vals: np.ndarray,
                                           mask_img: np.ndarray,
                                           method: str = 'srtm2',
                                           **analysis_kwargs) -> np.ndarray:
    """
    Generates parametric images for 4D-PET data using a reference tissue method, weighting each
    voxel fit with `weight_tac_simple` uncertainties.

    Args:
        scan_timing (ScanTimingInfo): Frame timing for the PET image. ``center`` and ``duration``
            are expected in seconds, which is what :func:`get_frame_timing_info_for_metadata`
            produces from BIDS metadata. They are converted to minutes before fitting.
        tgt_image (np.ndarray): A 4D array representing the 3D PET image over time.
            The shape of this array should be (x, y, z, time).
        ref_tac_vals (np.ndarray): A 1D array representing the reference TAC values, sampled at
            the same frames as `scan_timing`.
        mask_img (np.ndarray): A 3D array representing the brain mask for `tgt_image`. Voxels with
            a mask value greater than 0.5 are fitted; the rest are left at zero.
        method (str): RTM method name. Default 'srtm2'.
        **analysis_kwargs: Extra keyword arguments forwarded to the RTM fit function. The
            presence of a ``bounds`` key selects the bounded fit variant.

    Returns:
        params_img (np.ndarray): A 4D array with RTM parameter fit results based on the supplied
            method.
    """
    bounds = True if "bounds" in analysis_kwargs else None
    analysis_func = get_rtm_method(method=method, bounds=bounds)

    # Frame timing is in seconds (BIDS convention), but the fit and the weighting helpers take
    # minutes. Converting once here also hoists this loop-invariant work out of the voxel loop.
    frame_centers_in_minutes = np.asarray(scan_timing.center, float) / 60.0
    frame_durations_in_minutes = np.asarray(scan_timing.duration, float) / 60.0

    img_dims = tgt_image.shape
    output_shape = get_rtm_output_size(method=method)
    params_img = np.zeros((img_dims[0], img_dims[1], img_dims[2], output_shape), float)

    for i in range(img_dims[0]):
        for j in range(img_dims[1]):
            for k in range(img_dims[2]):
                if mask_img[i, j, k] > 0.5:
                    tac_vals = tgt_image[i, j, k, :]
                    voxel_uncertainties = weight_tac_simple(
                        tac_durations_in_minutes=frame_durations_in_minutes,
                        tac_vals=tac_vals)
                    analysis_vals = analysis_func(tac_times_in_minutes=frame_centers_in_minutes,
                                                  ref_tac_vals=ref_tac_vals,
                                                  tgt_tac_vals=tac_vals,
                                                  uncertainties=voxel_uncertainties,
                                                  **analysis_kwargs)
                    params_img[i, j, k] = analysis_vals[0]

    return params_img

In [None]:
def apply_rtm2_to_all_voxels_weight_decay(scan_timing: 'ScanTimingInfo',
                                           tgt_image: np.ndarray,
                                           ref_tac_vals: np.ndarray,
                                           mask_img: np.ndarray,
                                           half_life: float,
                                           method: str = 'srtm2',
                                           **analysis_kwargs) -> np.ndarray:
    """
    Generates parametric images for 4D-PET data using a reference tissue method, weighting each
    voxel fit with `weight_tac_decay` uncertainties.

    Args:
        scan_timing (ScanTimingInfo): Frame timing for the PET image. ``center`` and ``duration``
            are expected in seconds, which is what :func:`get_frame_timing_info_for_metadata`
            produces from BIDS metadata. They are converted to minutes before fitting.
        tgt_image (np.ndarray): A 4D array representing the 3D PET image over time.
            The shape of this array should be (x, y, z, time).
        ref_tac_vals (np.ndarray): A 1D array representing the reference TAC values, sampled at
            the same frames as `scan_timing`.
        mask_img (np.ndarray): A 3D array representing the brain mask for `tgt_image`. Voxels with
            a mask value greater than 0.5 are fitted; the rest are left at zero.
        half_life (float): Radionuclide half-life, forwarded unchanged to `weight_tac_decay` as
            ``half_life_in_minutes``, so it must be given in minutes. BIDS
            ``RadionuclideHalfLife`` is stored in seconds; convert it before calling.
        method (str): RTM method name. Default 'srtm2'.
        **analysis_kwargs: Extra keyword arguments forwarded to the RTM fit function. The
            presence of a ``bounds`` key selects the bounded fit variant.

    Returns:
        params_img (np.ndarray): A 4D array with RTM parameter fit results based on the supplied
            method.
    """
    bounds = True if "bounds" in analysis_kwargs else None
    analysis_func = get_rtm_method(method=method, bounds=bounds)

    # Frame timing is in seconds (BIDS convention), but the fit and the weighting helpers take
    # minutes. Converting once here also hoists this loop-invariant work out of the voxel loop.
    frame_centers_in_minutes = np.asarray(scan_timing.center, float) / 60.0
    frame_durations_in_minutes = np.asarray(scan_timing.duration, float) / 60.0

    img_dims = tgt_image.shape
    output_shape = get_rtm_output_size(method=method)
    params_img = np.zeros((img_dims[0], img_dims[1], img_dims[2], output_shape), float)

    for i in range(img_dims[0]):
        for j in range(img_dims[1]):
            for k in range(img_dims[2]):
                if mask_img[i, j, k] > 0.5:
                    tac_vals = tgt_image[i, j, k, :]
                    voxel_uncertainties = weight_tac_decay(
                        tac_durations_in_minutes=frame_durations_in_minutes,
                        tac_vals=tac_vals,
                        half_life_in_minutes=half_life)
                    analysis_vals = analysis_func(tac_times_in_minutes=frame_centers_in_minutes,
                                                  ref_tac_vals=ref_tac_vals,
                                                  tgt_tac_vals=tac_vals,
                                                  uncertainties=voxel_uncertainties,
                                                  **analysis_kwargs)
                    params_img[i, j, k] = analysis_vals[0]

    return params_img

In [9]:
def get_frame_timing_info_for_metadata(metadata_path: Union[str, dict]) -> 'ScanTimingInfo':
    r"""
    Extracts frame timing information and decay factors from JSON metadata.
    Expects that the metadata has ``FrameDuration`` and ``DecayFactor`` or
    ``DecayCorrectionFactor`` keys.

    .. important::
        If `FrameTimesStart` is missing, it is inferred from `FrameTimesEnd` (start = end -
        duration) or, failing that, by assuming contiguous frames starting at 0. If
        `FrameTimesEnd` is missing, it is computed as start + duration. If the scan is broken
        (gaps or overlaps between frames and no explicit start/end times), the inferred times
        may be incorrect.


    Args:
        metadata_path (str | dict): Path to the JSON file, or an already-loaded metadata dict.

    Returns:
        :class:`ScanTimingInfo`: Frame timing information with the following elements:
            - duration (np.ndarray): Frame durations in seconds.
            - start (np.ndarray): Frame start times in seconds.
            - end (np.ndarray): Frame end times in seconds.
            - center (np.ndarray): Frame center times in seconds (``FrameReferenceTime`` if
              present, else start + duration / 2).
            - decay (np.ndarray): Decay factors for each frame (``DecayCorrectionFactor``
              preferred over ``DecayFactor``).

    Raises:
        KeyError: If ``FrameDuration`` is missing, or if neither decay key is present.
    """
    # Callers in this notebook pass already-loaded metadata, so accept a dict as well as a path.
    if isinstance(metadata_path, dict):
        _meta_data = metadata_path
    else:
        _meta_data = safe_load_meta(input_metadata_file=metadata_path)

    frm_dur = np.asarray(_meta_data['FrameDuration'], float)

    if 'FrameTimesStart' in _meta_data:
        frm_starts = np.asarray(_meta_data['FrameTimesStart'], float)
    elif 'FrameTimesEnd' in _meta_data:
        frm_starts = np.asarray(_meta_data['FrameTimesEnd'], float) - frm_dur
    else:
        # Contiguous frames beginning at t=0: start_i = sum of the preceding durations.
        frm_starts = np.cumsum(frm_dur) - frm_dur

    if 'FrameTimesEnd' in _meta_data:
        frm_ends = np.asarray(_meta_data['FrameTimesEnd'], float)
    else:
        frm_ends = frm_starts + frm_dur

    if 'DecayCorrectionFactor' in _meta_data:
        decay = np.asarray(_meta_data['DecayCorrectionFactor'], float)
    else:
        decay = np.asarray(_meta_data['DecayFactor'], float)

    if 'FrameReferenceTime' in _meta_data:
        frm_centers = np.asarray(_meta_data['FrameReferenceTime'], float)
    else:
        frm_centers = np.asarray(frm_starts + frm_dur / 2.0, float)

    return ScanTimingInfo(duration=frm_dur, start=frm_starts, end=frm_ends, center=frm_centers, decay=decay)

In [10]:
def get_frame_timing_info_for_nifti(image_path: str) -> ScanTimingInfo:
    r"""
    Builds frame timing information and decay factors for a NIfTI image from the metadata that
    shares its filename. The metadata is expected to provide ``FrameDuration`` and either
    ``DecayFactor`` or ``DecayCorrectionFactor``.

    .. important::
        Missing `FrameTimesEnd` / `FrameTimesStart` entries are inferred from the frame
        durations by :func:`get_frame_timing_info_for_metadata`. For a broken scan (gaps or
        overlapping frames) these inferred times may be wrong.


    Args:
        image_path (str): Path to the NIfTI image file.

    Returns:
        scan_timing (:class:`ScanTimingInfo`): Frame timing information with the following elements:
            - duration (np.ndarray): Frame durations in seconds.
            - start (np.ndarray): Frame start times in seconds.
            - end (np.ndarray): Frame end times in seconds.
            - center (np.ndarray): Frame center times in seconds.
            - decay (np.ndarray): Decay factors for each frame.
    """
    # NOTE(review): the object returned by the metadata loader is handed straight through as
    # `metadata_path`; this works only if get_frame_timing_info_for_metadata accepts
    # already-loaded metadata and not just a file path. TODO confirm.
    return get_frame_timing_info_for_metadata(
        load_metadata_for_nifti_with_same_filename(image_path=image_path))

In [None]:
class ReferenceTissueParametricImage:
    """
    Class for generating parametric images of 4D-PET images using reference tissue model (RTM)
    methods.

    The RTM method is chosen at construction time. The fit is run voxel-wise inside the
    supplied mask.

    Example:
        .. code-block:: python

            rtm_parametric = ReferenceTissueParametricImage(reference_tac_path='/path/to/tac.tsv',
                                                            pet_image_path='/path/to/pet.nii.gz',
                                                            mask_image_path='/path/to/mask.nii.gz',
                                                            output_directory='/path/to/output',
                                                            output_filename_prefix='sub-001_mrtm2',
                                                            method='mrtm2')
            rtm_parametric.run_parametric_analysis(k2_prime=0.01,
                                                   t_thresh_in_mins=30)
            rtm_parametric.save_parametric_images()

    """
    def __init__(self,
                 reference_tac_path: str,
                 pet_image_path: str,
                 mask_image_path: str,
                 output_directory: str,
                 output_filename_prefix: str,
                 method: str='mrtm2'):
        """
        Initialize ReferenceTissueParametricImage with input values.

        Args:
            reference_tac_path (str): Path to the reference region TAC file.
            pet_image_path (str): Path to the 4D PET image on which kinetic analysis is performed.
            mask_image_path (str): Path to image that masks the brain in the same space as the PET
                image.
            output_directory (str): Path to folder where analysis is saved.
            output_filename_prefix (str): Prefix for output files saved after analysis.
            method (str): RTM method to run. Default 'mrtm2'.

        Raises:
            ValueError: If `method` is not a supported RTM method (from `init_analysis_props`).
        """
        self.reference_tac = TimeActivityCurveFromFile(tac_path=reference_tac_path)
        self.pet_image = safe_load_4dpet_nifti(pet_image_path)
        self.mask_image = safe_load_4dpet_nifti(mask_image_path)
        self.metadata = load_metadata_for_nifti_with_same_filename(pet_image_path)

        validate_two_images_same_dimensions(self.pet_image,self.mask_image,check_4d=False)

        self.output_directory = output_directory
        self.output_filename_prefix = output_filename_prefix
        self.method = method
        self.analysis_props = self.init_analysis_props(method)
        # None until run_parametric_analysis stores the parameter image here.
        self.fit_results = None


    def get_time_dependent_properties(self):
        """
        Extract scan timing info and half life from the PET image metadata.

        Sets ``self.half_life`` and ``self.scan_timing``.

        NOTE(review): ``self.metadata`` holds whatever ``load_metadata_for_nifti_with_same_filename``
        returned, while both helpers below take a ``*_path`` argument. If that loader returns a
        parsed dict rather than a file path, these calls need the sidecar path instead. TODO confirm.
        """
        self.half_life = get_half_life_from_meta(meta_data_file_path=self.metadata)
        self.scan_timing = get_frame_timing_info_for_metadata(metadata_path=self.metadata)


    def init_analysis_props(self, method: str) -> dict:
        r"""
        Initializes the analysis properties dict based on the specified RTM analysis method.

        Args:
            method (str): RTM analysis method. Must start with 'srtm', 'frtm' or 'mrtm'
                (e.g. 'srtm', 'srtm2', 'frtm', 'frtm2', 'mrtm-original', 'mrtm', 'mrtm2').

        Returns:
            dict: A dictionary containing method-specific property keys and default values.

        Raises:
            ValueError: If input `method` is not one of the supported RTM methods.
        """
        common_props = {'MethodName': method.upper()}
        if method.startswith("mrtm"):
            props = {
                'BP': None,
                'k2Prime': None,
                'ThresholdTime': None,
                'Bounds': None,
                'StartFrameTime': None,
                'EndFrameTime' : None,
                'NumberOfPointsFit': None,
                'RawFits': None,
                **common_props
                }
        elif method.startswith("srtm") or method.startswith("frtm"):
            props = {
                'FitValues': None,
                'FitStdErr': None,
                **common_props
                }
        else:
            raise ValueError(f"Invalid method! Must be either 'srtm', 'frtm', 'srtm2', 'frtm2', "
                             f"'mrtm-original', 'mrtm' or 'mrtm2'. Got {method}.")
        return props


    def set_analysis_props(self,
                           props: dict,
                           bounds: Union[None, np.ndarray] = None,
                           k2_prime: float=None,
                           t_thresh_in_mins: float=None,
                           image_scale: float=None):
        """
        Record the analysis settings in a properties dict (mutated in place).

        Args:
            props (dict): Properties dict to update, e.g. ``self.analysis_props``.
            bounds (Union[None, np.ndarray]): Bounds used for the fit. Stored under 'Bounds'.
            k2_prime (float): k2' value used for the fit. Stored under 'k2Prime'.
            t_thresh_in_mins (float): Threshold time used for the fit. Stored under
                'ThresholdTime'.
            image_scale (float): Scale applied to the PET image. Stored under 'ImageScale'.
        """
        props['Bounds'] = bounds
        props['k2Prime'] = k2_prime
        props['ThresholdTime'] = t_thresh_in_mins
        props['ImageScale'] = image_scale

    # TODO: get scan timing from image bids meta not TAC
    # that way we get reference times and durations and half life
    def apply_rtm2_to_all_voxels_weight_decay(self,
                                              scan_timing: 'ScanTimingInfo',
                                              tgt_image: np.ndarray,
                                              ref_tac_vals: np.ndarray,
                                              mask_img: np.ndarray,
                                              half_life: float,
                                              method: str = 'srtm2',
                                              **analysis_kwargs) -> np.ndarray:
        """
        Generates parametric images for 4D-PET data using a reference tissue method, weighting
        each voxel fit with `weight_tac_decay` uncertainties.

        Args:
            scan_timing (ScanTimingInfo): Frame timing for the PET image. ``center`` and
                ``duration`` are expected in seconds (as produced from BIDS metadata). They are
                converted to minutes before fitting.
            tgt_image (np.ndarray): A 4D array representing the 3D PET image over time.
                The shape of this array should be (x, y, z, time).
            ref_tac_vals (np.ndarray): A 1D array representing the reference TAC values, sampled
                at the same frames as `scan_timing`.
            mask_img (np.ndarray): A 3D array representing the brain mask for `tgt_image`.
                Voxels with a mask value greater than 0.5 are fitted; the rest are left at zero.
            half_life (float): Radionuclide half-life, forwarded unchanged to `weight_tac_decay`
                as ``half_life_in_minutes``, so it must be given in minutes.
            method (str): RTM method name. Default 'srtm2'.
            **analysis_kwargs: Extra keyword arguments forwarded to the RTM fit function. The
                presence of a ``bounds`` key selects the bounded fit variant.

        Returns:
            params_img (np.ndarray): A 4D array with RTM parameter fit results based on the
                supplied method.
        """
        bounds = True if "bounds" in analysis_kwargs else None
        analysis_func = get_rtm_method(method=method, bounds=bounds)

        # Seconds -> minutes, computed once outside the voxel loop.
        frame_centers_in_minutes = np.asarray(scan_timing.center, float) / 60.0
        frame_durations_in_minutes = np.asarray(scan_timing.duration, float) / 60.0

        img_dims = tgt_image.shape
        output_shape = get_rtm_output_size(method=method)
        params_img = np.zeros((img_dims[0], img_dims[1], img_dims[2], output_shape), float)

        for i in range(img_dims[0]):
            for j in range(img_dims[1]):
                for k in range(img_dims[2]):
                    if mask_img[i, j, k] > 0.5:
                        tac_vals = tgt_image[i, j, k, :]
                        voxel_uncertainties = weight_tac_decay(
                            tac_durations_in_minutes=frame_durations_in_minutes,
                            tac_vals=tac_vals,
                            half_life_in_minutes=half_life)
                        analysis_vals = analysis_func(tac_times_in_minutes=frame_centers_in_minutes,
                                                      ref_tac_vals=ref_tac_vals,
                                                      tgt_tac_vals=tac_vals,
                                                      uncertainties=voxel_uncertainties,
                                                      **analysis_kwargs)
                        params_img[i, j, k] = analysis_vals[0]

        return params_img

    def run_parametric_analysis(self,
                                bounds: Union[None, np.ndarray] = None,
                                k2_prime: float=None,
                                t_thresh_in_mins: float=None,
                                image_scale: float=1):
        """
        Run the voxel-wise analysis with the RTM method chosen at construction (``self.method``).

        Args:
            bounds (Union[None, np.ndarray]): Bounds on fit parameters. See
                :py:func:`get_rtm_kwargs`. Default None.
            k2_prime (float): k2' value set for all voxel-wise analysis. Default None.
            t_thresh_in_mins (float): Threshold time after which kinetic parameters are fit.
                Default None.
            image_scale (float): Factor the PET image is multiplied by before fitting. ``None``
                is treated as 1. Default 1.

        Returns:
            np.ndarray: The parametric image of shape (x, y, z, n_params). It is also stored in
            ``self.fit_results``.
        """
        if image_scale is None:
            image_scale = 1
        pet_np = self.pet_image.get_fdata()
        mask_np = self.mask_image.get_fdata()
        tac_times_in_minutes = self.reference_tac.tac_times_in_minutes
        ref_tac_vals = self.reference_tac.tac_vals
        method = self.method
        # Select the same (bounded or unbounded) fit function that the voxel loop will use, so
        # get_rtm_kwargs builds keyword arguments matching that function.
        rtm_method = get_rtm_method(method=method, bounds=True if bounds is not None else None)
        analysis_kwargs = get_rtm_kwargs(method=rtm_method,
                                         bounds=bounds,
                                         k2_prime=k2_prime,
                                         t_thresh_in_mins=t_thresh_in_mins)

        # Module-level helper: the class has no `apply_rtm2_to_all_voxels` method of its own.
        fit_results = apply_rtm2_to_all_voxels(tac_times_in_minutes=tac_times_in_minutes,
                                               tgt_image=pet_np * image_scale,
                                               ref_tac_vals=ref_tac_vals,
                                               mask_img=mask_np,
                                               method=method,
                                               **analysis_kwargs)
        self.fit_results = fit_results
        return fit_results


    def save_parametric_images(self):
        """
        Save the parametric image to
        ``{output_directory}/{output_filename_prefix}_desc-rtmfit_pet.nii.gz``, using the PET
        image's affine and header.

        Raises:
            RuntimeError: If no fit results exist yet.
            IOError: If writing the NIfTI file fails.
        """
        if self.fit_results is None:
            raise RuntimeError("No fit results to save; run `run_parametric_analysis` first.")
        pet_image = self.pet_image
        fit_nibabel = nibabel.nifti1.Nifti1Image(dataobj=self.fit_results,
                                                 affine=pet_image.affine,
                                                 header=pet_image.header)
        fit_image_path = os.path.join(self.output_directory,
                                      f"{self.output_filename_prefix}_desc-rtmfit_pet.nii.gz")
        try:
            nibabel.save(fit_nibabel, fit_image_path)
        except IOError as exc:
            print("An IOError occurred while attempting to write the NIfTI image files.")
            raise exc from None


    def save_analysis_properties(self):
        """
        Saves the analysis properties to a JSON file in the output directory.

        The properties dict (``self.analysis_props``) holds the method name and the
        method-specific keys created by `init_analysis_props`, plus any settings recorded by
        `set_analysis_props`. It is written to
        ``{output_directory}/{output_filename_prefix}_desc-{MethodName}_props.json``. NumPy
        arrays and scalars (e.g. ``Bounds``) are converted to plain lists/numbers first.

        Raises:
            IOError: An error occurred accessing the output_directory or while writing to the JSON
                file.
            TypeError: If a property value is not JSON serializable.
        """
        def _to_json_compatible(obj):
            # numpy arrays and numpy scalars both expose `tolist()` returning plain Python types.
            if hasattr(obj, 'tolist'):
                return obj.tolist()
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        analysis_props_file = os.path.join(self.output_directory,
                                           f"{self.output_filename_prefix}_desc-"
                                           f"{self.analysis_props['MethodName']}_props.json")
        with open(analysis_props_file, 'w', encoding='utf-8') as f:
            json.dump(obj=self.analysis_props, fp=f, indent=4, default=_to_json_compatible)


    def __call__(self,
                 bounds: np.ndarray=None,
                 t_thresh_in_mins: float=None,
                 k2_prime: float=None,
                 image_scale: float=None):
        """
        Run the analysis, record its settings, and save the parametric image and properties.

        Args:
            bounds (np.ndarray): Bounds on fit parameters. Default None.
            t_thresh_in_mins (float): Threshold time for the fit. Default None.
            k2_prime (float): k2' value for the fit. Default None.
            image_scale (float): Factor applied to the PET image. ``None`` means 1.
        """
        if image_scale is None:
            image_scale = 1
        self.run_parametric_analysis(bounds=bounds,
                                     t_thresh_in_mins=t_thresh_in_mins,
                                     k2_prime=k2_prime,
                                     image_scale=image_scale)
        self.set_analysis_props(props=self.analysis_props,
                                bounds=bounds,
                                k2_prime=k2_prime,
                                t_thresh_in_mins=t_thresh_in_mins,
                                image_scale=image_scale)
        self.save_parametric_images()
        self.save_analysis_properties()
