In [1]:
import functools
import json

import nibabel
import numpy as np

In [2]:
class BaseImage(np.lib.mixins.NDArrayOperatorsMixin):
    """
    Minimal array container that takes part in NumPy's dispatch protocols.

    The wrapped values live in ``self.data`` (always an ``np.ndarray``).
    Arithmetic operators come from ``NDArrayOperatorsMixin`` and are routed
    through ``__array_ufunc__``, so ``image * 2`` or ``np.add(image, other)``
    returns a new instance of the same class.

    Note:
        Results of ufuncs are built with ``self.__class__(result)`` only, so
        any extra constructor arguments a subclass accepts fall back to their
        defaults on the returned object.
    """

    def __init__(self, data):
        """Store ``data`` as an ndarray (no copy is made if it already is one)."""
        self.data = np.asarray(data)

    def __repr__(self):
        return f"{self.__class__.__name__}(data={self.data})"

    def __array__(self, dtype=None, copy=None):
        """
        Return the underlying array for ``np.asarray(image)`` and friends.

        Args:
            dtype: Optional dtype requested by NumPy; honoured via ``astype``.
            copy: ``True`` forces a copy, ``False`` forbids one, ``None`` copies
                only when a dtype conversion requires it.

        Raises:
            ValueError: If ``copy=False`` but a dtype conversion is needed.
        """
        needs_cast = dtype is not None and np.dtype(dtype) != self.data.dtype
        if copy is False:
            if needs_cast:
                raise ValueError(
                    "`copy=False` isn't supported when a dtype conversion is required."
                )
            # No conversion needed: the stored array can be handed out as-is.
            return self.data
        if needs_cast:
            # astype always allocates, which also satisfies copy=True.
            return self.data.astype(dtype)
        return self.data.copy() if copy else self.data

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """
        Apply ``ufunc`` to the wrapped arrays and re-wrap the result.

        Only plain calls (``method == '__call__'``) are supported; reductions,
        ``accumulate``, ``at`` and so on return ``NotImplemented``.
        """
        if method != '__call__':
            return NotImplemented

        # Unwrap every BaseImage operand. Passing the wrapper itself back into
        # the ufunc would make NumPy dispatch to this method again and recurse
        # until the interpreter's recursion limit is hit.
        raw_inputs = tuple(
            operand.data if isinstance(operand, BaseImage) else operand
            for operand in inputs
        )
        if kwargs.get('out') is not None:
            kwargs['out'] = tuple(
                target.data if isinstance(target, BaseImage) else target
                for target in kwargs['out']
            )

        result = ufunc(*raw_inputs, **kwargs)
        if isinstance(result, tuple):
            # Multi-output ufuncs (e.g. np.modf) return one array per output.
            return tuple(self.__class__(part) for part in result)
        return self.__class__(result)

    def __len__(self):
        return len(self.data)

    def __getitem__(self,a):
        # Indexing returns whatever the ndarray returns (a scalar or a plain
        # ndarray), not a new BaseImage.
        return self.data[a]

    @property
    def shape(self):
        """Shape tuple of the wrapped array."""
        return self.data.shape

In [3]:
class HeaderImage(BaseImage):
    """
    Image array paired with a header object.

    Args:
        data: Array-like image values, stored by ``BaseImage``.
        header: Header to attach to the image. When omitted, each instance gets
            its own new empty dict.
    """
    def __init__(self,data,header: dict=None):
        BaseImage.__init__(self,data)
        # A `{}` default in the signature would be one dict shared by every
        # instance created without a header; build a fresh one per instance.
        self.header = {} if header is None else header

    def __repr__(self):

        return f"{self.__class__.__name__}(data={self.data},header={self.header})"


In [4]:
def load(image_path: str):
    """
    Read an image file with nibabel and wrap it in a ``HeaderImage``.

    Args:
        image_path (str): Path handed to ``nibabel.load``.

    Returns:
        HeaderImage: Holds the array from ``get_fdata()`` as ``data`` and the
            nibabel header object as ``header``.
    """
    nib_image = nibabel.load(image_path)
    return HeaderImage(data=nib_image.get_fdata(), header=nib_image.header)

In [5]:
# Example: load an image (file named *_suvr) into a HeaderImage.
# NOTE(review): hard-coded absolute path; this cell only runs where that path exists.
ex_header_image = load('/export/scratch1/PETPAL/PETPAL312/sub-VATDYS017/sub-VATDYS017_suvr.nii.gz')

In [6]:
def load_metadata(meta_path) -> dict:
    """
    Parse a UTF-8 encoded JSON file and return its contents.

    Args:
        meta_path: Path of the JSON file to read.

    Returns:
        dict: The decoded JSON document.
    """
    with open(meta_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)

In [7]:
# Example: read the JSON sidecar into a plain dict.
# NOTE(review): `ex_bidsmeta` is rebound further down to the return value of
# load_bidsmeta (a BidsMeta object), so this name holds two different types
# depending on which cell ran last. Hard-coded absolute path.
ex_bidsmeta = load_metadata('/data/norris/data1/data_archive/VATDYS/sub-VATDYS017/pet/sub-VATDYS017_pet.json')

In [8]:
# Example: load the image (file named *_pet) that the later cells wrap in a BidsImage.
# NOTE(review): hard-coded absolute path.
ex_image = load('/data/norris/data1/data_archive/VATDYS/sub-VATDYS017/pet/sub-VATDYS017_pet.nii.gz')

In [10]:
class BidsMeta:
    """
    Empty attribute container for image metadata.

    The class defines no fields of its own; attributes are attached to
    instances after construction.
    """

    def __init__(self):
        """Create a container with no attributes set."""
        return None


In [60]:
class BidsImage(HeaderImage,BidsMeta):
    """
    Image array with a header and an attached metadata object.

    Args:
        data: Array-like image values.
        header: Header to attach. When omitted, each instance gets its own new
            empty dict.
        bidsmeta (BidsMeta): Metadata object stored as ``self.bidsmeta``;
            defaults to ``None``.
    """
    def __init__(self,data,header: dict=None,bidsmeta: BidsMeta = None):
        # Resolve the default here so every instance owns its header; a `{}`
        # default in the signature would be shared between instances.
        header = {} if header is None else header
        HeaderImage.__init__(self,data,header)
        BidsMeta.__init__(self)
        self.bidsmeta = bidsmeta

    def __repr__(self):

        return f"{self.__class__.__name__}(data={self.data},header={self.header},bidsmeta={self.bidsmeta})"

In [65]:
def image_decorator(func,bids_image=None,*args,**kwargs):
    """
    Optionally bind an image object's array to an array-to-array function.

    Used bare (``@image_decorator``) no image is supplied, so ``func`` is
    returned unchanged. Called explicitly with ``bids_image=...`` it returns a
    wrapper that supplies ``bids_image.data`` as the ``input_image_array``
    keyword argument of ``func``.

    Args:
        func: Function that accepts an ``input_image_array`` keyword argument.
        bids_image: Object exposing a ``.data`` attribute, or ``None``.
        *args, **kwargs: Accepted for signature compatibility and ignored here.

    Returns:
        ``func`` itself when ``bids_image`` is ``None``, otherwise the wrapper.
    """
    # Compare against None explicitly: truthiness would consult the image's
    # __len__ and treat an image with a zero-length first axis as "absent".
    if bids_image is None:
        return func

    @functools.wraps(func)
    def wrapper(*call_args,**call_kwargs):
        # Read .data at call time so the wrapper sees the image's current array.
        return func(*call_args,input_image_array=bids_image.data,**call_kwargs)
    return wrapper

In [66]:
@image_decorator
def weighted_sum_computation(input_image_array: np.ndarray,
                             frame_duration: np.ndarray,
                             half_life: float,
                             frame_start: np.ndarray,
                             decay_correction: np.ndarray):
    """
    Collapse a PET frame series into one duration-weighted, decay-recorrected image.

    Each frame is multiplied by its duration and divided by its decay
    correction factor, the frames are summed along axis 3, and the sum is
    rescaled by a decay term for the whole acquisition and divided by the
    total duration.

    Args:
        input_image_array (np.ndarray): PET image series with frames on axis 3.
        frame_duration (np.ndarray): Duration of each frame in the series.
        half_life (float): Half life of the tracer radioisotope in seconds.
        frame_start (np.ndarray): Start time of each frame, measured with
            respect to scan TimeZero; only the first entry is used.
        decay_correction (np.ndarray): Decay correction factor that scales
            each frame in the series.

    Returns:
        np.ndarray: The weighted-sum image, with axis 3 of the input removed.

    See Also:
        * :meth:`petpal.image_operations_4d.weighted_series_sum`: Function where this is implemented.

    """
    decay_rate = np.log(2.0) / half_life
    scan_duration = np.sum(frame_duration)

    # Decay terms for the whole acquisition and for the delay before frame 0.
    decayed_fraction = 1.0 - np.exp(-1.0 * decay_rate * scan_duration)
    first_frame_decay = np.exp(-1 * decay_rate * frame_start[0])
    recorrection = decay_rate * scan_duration / decayed_fraction / first_frame_decay

    # The slice goes through the input's own __getitem__ before any arithmetic.
    frames_weighted = input_image_array[:, :, :] * frame_duration / decay_correction
    summed_over_frames = np.sum(frames_weighted, axis=3)
    return summed_over_frames * recorrection / scan_duration

In [12]:
class PetImage4d(BidsImage):
    """
    ``BidsImage`` specialisation that adds PET-specific accessors.

    Args:
        data: Array-like image values.
        header: Header to attach. When omitted, each instance gets its own new
            empty dict.
        bidsmeta (BidsMeta): Metadata object; defaults to ``None``.
    """
    def __init__(self,data,header: dict=None,bidsmeta: BidsMeta = None):
        # Resolve the default per instance; a `{}` default in the signature
        # would be one dict shared by every PetImage4d built without a header.
        header = {} if header is None else header
        BidsImage.__init__(self,data,header,bidsmeta)

    def __repr__(self):
        return f"{self.__class__.__name__}(data={self.data},header={self.header},bidsmeta={self.bidsmeta})"

    @property
    def half_life(self):
        """
        Tracer half life; currently a fixed constant.

        NOTE(review): the value does not depend on the attached metadata.
        Presumably it is in seconds and corresponds to a single radionuclide;
        confirm before using this with other tracers.
        """
        return 6.58404*1000 # TODO: figure out best way to store values and convert from bidsmeta.TracerRadionuclide

    def validate_required_fields(self):
        """Placeholder; performs no validation yet and returns ``None``."""
        # idea: useful but might be better elsewhere
        # go through fields required for PET and validate at least those used by petpal software
        pass

In [13]:
def load_bidsmeta(meta_path) -> BidsMeta:
    """
    Read a JSON file and expose its top-level entries as object attributes.

    Every key of the decoded document becomes an attribute on a new
    ``BidsMeta`` instance, with the key lower-cased (``"FrameDuration"`` is
    stored as ``frameduration``).

    Args:
        meta_path: Path of the UTF-8 encoded JSON file.

    Returns:
        BidsMeta: Container carrying one attribute per top-level JSON key.
    """
    with open(meta_path, 'r', encoding='utf-8') as handle:
        sidecar = json.load(handle)

    container = BidsMeta()
    for field_name, field_value in sidecar.items():
        setattr(container, field_name.lower(), field_value)
    return container

In [67]:
# Rebind `ex_bidsmeta` to a BidsMeta object (it held a plain dict from load_metadata before).
# NOTE(review): hard-coded absolute path.
ex_bidsmeta = load_bidsmeta('/data/norris/data1/data_archive/VATDYS/sub-VATDYS017/pet/sub-VATDYS017_pet.json')

In [68]:
# Combine the loaded array and header with the metadata object into one BidsImage.
bids2 = BidsImage(data=ex_image.data,header=ex_image.header,bidsmeta=ex_bidsmeta)

In [70]:
# Weighted sum with the image passed by keyword. The metadata attribute names
# are the lower-cased JSON keys produced by load_bidsmeta.
# NOTE(review): the half_life literal repeats the constant in PetImage4d.half_life.
wss = weighted_sum_computation(
    input_image_array=bids2,
    frame_duration=bids2.bidsmeta.frameduration,
    half_life=6.58404*1000,
    frame_start=bids2.bidsmeta.frametimesstart,
    decay_correction=bids2.bidsmeta.decaycorrectionfactor
)

In [72]:
# Spot-check a single voxel of the keyword-call result.
wss[50,50,32]

np.float64(8795.447693847533)

In [73]:
# Same computation as `wss`, but with the image passed positionally.
wss2 = weighted_sum_computation(
    bids2,
    frame_duration=bids2.bidsmeta.frameduration,
    half_life=6.58404*1000,
    frame_start=bids2.bidsmeta.frametimesstart,
    decay_correction=bids2.bidsmeta.decaycorrectionfactor
)

In [74]:
# Same voxel as the `wss` spot-check, for comparing the two call styles.
wss2[50,50,32]

np.float64(8795.447693847533)

In [19]:
# Multiply the wrapped ndarray directly; `bids3` is a plain ndarray, not a BidsImage.
bids3 = bids2.data * 5

In [20]:
# Display the whole array.
# NOTE(review): this dumps the full repr; an indexed slice or `.shape` would be easier to read.
bids3

array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
    

In [21]:
# Pass the BidsImage wrapper itself (not `.data`) to a NumPy reduction.
np.std(bids2)

np.float64(8351.391213989684)

In [49]:
# defining a decorator
def hello_decorator(func):
    """
    Wrap ``func`` so a message is printed before and after every call.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper that prints ``func`` and a "before" line, calls ``func`` with
        the given positional and keyword arguments, prints an "after" line,
        and returns whatever ``func`` returned.
    """
    # functools.wraps keeps the wrapped function's __name__ and __doc__ on the
    # wrapper, so decorated functions still identify themselves correctly.
    @functools.wraps(func)
    def inner1(*args, **kwargs):
        # The closure gives the wrapper access to the outer local `func`.
        print(func)
        print("Hello, this is before function execution")

        # Call the wrapped function and keep its result so it can be returned.
        result = func(*args, **kwargs)

        print("This is after function execution")
        return result

    return inner1


# defining a function, to be called inside wrapper
def function_to_be_used():
    """Print a fixed message; takes no arguments and returns ``None``."""
    message = "This is inside the function !!"
    print(message)


# Decorate manually: pass 'function_to_be_used' to the decorator and rebind
# the name to the returned wrapper. This is what the `@hello_decorator`
# syntax does at definition time.
function_to_be_used = hello_decorator(function_to_be_used)


# calling the function
#function_to_be_used()

In [46]:
@hello_decorator
def function_to_be_used2(arg: str):
    """Print ``arg`` inside a fixed sentence (demonstrates decorating a function that takes an argument)."""
    message = f"This is an arg: {arg}"
    print(message)

In [47]:
# Call the decorated function; the wrapper's messages print around the function's own output.
function_to_be_used2('bello')

<function function_to_be_used2 at 0x7f0687778ea0>
Hello, this is before function execution
This is an arg: bello
This is after function execution
