# Setup

In [1]:
from notebook_viewer_functions import *
from functions import *
from scivol import *
import numpy as np
import json
import os
import ants
import gzip
import matplotlib.pyplot as plt
from ipywidgets import interact

# Resolve all in-project input files relative to the project root.
proj_root = parent_directory()
print(f"project root: {proj_root}")
t1_input_filepath = os.path.join(proj_root, "media/sub-01/anat/sub-01_T1w.nii.gz")
bold_stim_filepath = os.path.join(proj_root, "media/sub-01/func/sub-01_task-emotionalfaces_run-1_bold.nii.gz")
bold_rest_filepath = os.path.join(proj_root, "media/sub-01/func/sub-01_task-rest_bold.nii.gz")
mni_anat_filepath = os.path.join(proj_root, "templates/mni_icbm152_t1_tal_nlin_sym_09a.nii")
mni_mask_filepath = os.path.join(proj_root, "templates/mni_icbm152_t1_tal_nlin_sym_09a_mask.nii")
events_tsv_path = os.path.join(proj_root, "media/sub-01/func/task-emotionalfaces_run-1_events.tsv")

# The PsychoPy emotional-faces task repository lives outside this project.
# Its location can be overridden with the EMOFACES_REPO environment variable;
# the default is the original author's checkout, so the paths are unchanged
# when the variable is not set.
emofaces_repo = os.environ.get(
    "EMOFACES_REPO",
    "/Users/joachimpfefferkorn/repos/emotional-faces-psychopy-task-main",
)
stimulus_image_path = os.path.join(emofaces_repo, "emofaces/POFA/fMRI_POFA")
log_path = os.path.join(emofaces_repo, "emofaces/data/01-subject_emofaces1_2019_Aug_14_1903.log")

# Load the subject's images and the MNI template (plus its mask).
raw_t1_img = ants.image_read(t1_input_filepath)
raw_stim_bold = ants.image_read(bold_stim_filepath)
raw_rest_bold_img = ants.image_read(bold_rest_filepath)
mni_img = ants.image_read(mni_anat_filepath)
mni_mask_img = ants.image_read(mni_mask_filepath)

project root: /Users/joachimpfefferkorn/repos/neurovolume


## Functions

In [2]:
def compare_bold_alignment(bold_seq_vol: np.ndarray, anat_vol: np.ndarray):
    """
    Interactively compare z slices of a BOLD series with an anatomical volume.

    Three panels are drawn side by side: the BOLD slice ('hot' colormap), the
    anatomy slice (grayscale), and the BOLD slice overlaid on the anatomy with
    adjustable opacity. Sliders choose the z slice, the BOLD frame and the
    overlay opacity.

    arguments:
    ---------
    bold_seq_vol : np.ndarray
        4D array indexed as [x, y, z, frame].
    anat_vol : np.ndarray
        3D array indexed as [x, y, z]. The slice slider range is taken from
        this volume, so bold_seq_vol is assumed to be on the same grid
        (e.g. already registered to it); otherwise indexing may fail or the
        overlay will not line up.

    raises:
    ------
    ValueError
        If bold_seq_vol is not 4D or anat_vol is not 3D (e.g. swapped arguments).
    """
    if bold_seq_vol.ndim != 4 or anat_vol.ndim != 3:
        raise ValueError(
            "expected a 4D BOLD array and a 3D anatomy array, "
            f"got {bold_seq_vol.ndim}D and {anat_vol.ndim}D"
        )

    # Only the z dimension is browsed for now, one BOLD frame at a time.
    def show_z_slice(slice_idx, frame_idx, opacity):
        bold_slice = bold_seq_vol[:, :, slice_idx, frame_idx]
        anat_slice = anat_vol[:, :, slice_idx]

        _, axes = plt.subplots(1, 3, figsize=(15, 5))

        axes[0].imshow(bold_slice, cmap='hot')
        axes[0].set_title('BOLD')

        axes[1].imshow(anat_slice, cmap='gray')
        axes[1].set_title('Anatomy')

        axes[2].imshow(anat_slice, cmap='gray')
        axes[2].imshow(bold_slice, cmap='hot', alpha=opacity)
        axes[2].set_title('Overlay')

        # Render explicitly so the figure appears in the interact output area
        # and is replaced (not accumulated) when a slider moves.
        plt.show()

    interact(
        show_z_slice,
        slice_idx=(0, anat_vol.shape[2] - 1),
        frame_idx=(0, bold_seq_vol.shape[3] - 1),
        opacity=(0, 1.0),
    )

In [3]:
def _event_at_time(events_tsv, second):
    """
    Return the label of the event active at `second`, or None if there is none.

    `events_tsv` is the *text* of a tab-separated events file, not a path.
    Each data row is expected to start with onset, duration and a label
    (presumably BIDS onset/duration/trial_type columns; confirm against the
    file). Rows whose first two fields are not numbers, such as the header,
    are skipped, as are rows with fewer than three fields. If several events
    overlap `second`, the last one in the file wins.
    """
    present_event = None
    for row in events_tsv.splitlines():
        fields = row.split("\t")
        if len(fields) < 3:
            continue
        try:
            onset = float(fields[0])
            duration = float(fields[1])
        except ValueError:
            continue  # header or malformed row
        # Sum the parsed numbers; concatenating the raw strings would be wrong.
        if onset <= second < onset + duration:
            present_event = fields[2]
    return present_event


def explore_fMRI(ants_img: "ants.core.ants_image.ANTsImage",
                 volume_override="NULL",
                 dim="x", events_tsv="NULL",
                 cmap='nipy_spectral'):
    """
    Interactively browse a 4D fMRI image, one 2D slice and one frame at a time.

    arguments:
    ---------
    ants_img : ANTsImage
        4D image. Its data is displayed unless volume_override is given. Its
        4th spacing component is used to convert a frame index to seconds
        when events are looked up.
    volume_override : np.ndarray, optional
        4D array to display instead of ants_img.numpy(). Any non-array value
        (the default is "NULL") means "use ants_img".
    dim : {"x", "y", "z"}
        Axis along which 2D slices are taken.
    events_tsv : str, optional
        Text contents (not a path) of a tab-separated events file. When given,
        the label of the event active at the current frame is printed below
        the plot. None or "NULL" disables the lookup.
    cmap : str
        Matplotlib colormap name.

    raises:
    ------
    ValueError
        If dim is not "x", "y" or "z".
    """
    axis_by_dim = {"x": 0, "y": 1, "z": 2}
    if dim not in axis_by_dim:
        raise ValueError(f'dim must be "x", "y" or "z", got {dim!r}')
    slice_axis = axis_by_dim[dim]

    if isinstance(volume_override, np.ndarray):
        vol = volume_override
    else:
        vol = ants_img.numpy()
    has_events = events_tsv is not None and events_tsv != "NULL"

    # The parameter names `slice` and `frame` become the slider labels in
    # `interact`, so they are kept even though `slice` shadows the builtin.
    def plot(slice, frame):
        plt.figure()
        # vol[..., frame] is the 3D volume at this frame; take one 2D slice of it.
        plt.imshow(np.take(vol[..., frame], slice, axis=slice_axis), cmap=cmap)
        plt.show()
        if has_events:
            second = float(frame * ants_img.spacing[3])
            present_event = _event_at_time(events_tsv, second)
            print(present_event if present_event is not None else "No event")

    interact(plot, slice=(0, vol.shape[slice_axis] - 1), frame=(0, vol.shape[3] - 1))

## Loading Data

In [4]:
# Re-read the emotional-faces BOLD series and the T1 anatomy for the registration experiments below.
bold_image, t1_image = (
    ants.image_read(bold_stim_filepath),
    ants.image_read(t1_input_filepath),
)

# Previous Failures
These are some approaches I tried before. They explain why I'm writing out these large functions for functionality that should just be built into ANTs.

## Canonical Way, register all frames within one function

The following does not work (it crashes the kernel):

````python
bold_registered = ants.apply_transforms(
    fixed=t1_image,
    moving=bold_image,
    transformlist=registration['fwdtransforms'],
    interpolator='linear',
    imagetype=3
)
````

I suspect this is related to a bug in how ANTs handles `imagetype=3` (time-series images), since the same issue came up yesterday. Perhaps open an issue on GitHub.

As a workaround, let's loop over the frames along the 4th (time) dimension and register each one separately.

Given the computational cost of registration, let's store the results in a list.

## Gathering the Transforms, then registering
The following code proved problematic:
````python
registrations = []
for frame in range(bold_image.shape[3]):
    print(f"frame {frame}/{bold_image.shape[3]}")
    print(" creating bold frame")
    bold_frame = ants.from_numpy(bold_image.numpy()[:,:,:,frame])

    print(" Registering bold frame to T1 image")
    registration = ants.registration(
    fixed=t1_image,
    moving=bold_frame,
    type_of_transform='Rigid'  # You can also use 'Affine' or 'SyN' for deformable registration
    )
    registrations += registration #This is incorrect! It's just adding the var names
````
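When trying to apply these registrations, I got strange behavior. I first suspected that the `registration` values were getting garbled behind the scenes. However, the last line is already wrong: `registrations += registration` extends the list with the dictionary's keys (e.g. `'fwdtransforms'`) instead of storing the registration result itself. `registrations.append(registration)` was what was intended. To sidestep this, let's see if one big function that gathers the registrations and applies them does the trick!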

Also, this could have been a dictionary mapping each image to its transforms (`img: transforms`).

# 4D Registration

First, here's our main alignment function

In [None]:
def seq_to_4D(bold_frames: list, data_template_idx=0, template_img_idx=0, time_starts = 0.0, time_spacing=2.0):
    """
    Stack a list of 3D BOLD frames (ANTs images) into one 4D ANTs image.

    arguments:
    ---------
    bold_frames : list
        3D ANTs images of identical shape, in temporal order.
    data_template_idx : int
        Unused; accepted only so existing calls keep working.
    template_img_idx : int
        Index of the frame whose spatial metadata (origin, spacing, direction)
        is copied into the 4D image. Out-of-range values fall back to 0.
    time_starts : float
        The origin of the temporal (4th) dimension.
    time_spacing : float
        The temporal spacing, in seconds per frame.
        Set to 2.0 by default to match the dataset currently in use.

    returns:
    -------
    ANTsImage
        4D image whose data is indexed as [x, y, z, frame].

    raises:
    ------
    ValueError
        If bold_frames is empty.
    """
    if not bold_frames:
        raise ValueError("bold_frames must contain at least one frame")
    if not 0 <= template_img_idx < len(bold_frames):
        print("Template image index is not valid. Setting to first frame")
        template_img_idx = 0
    t = bold_frames[template_img_idx]  # t is for template

    combined_origin = (*t.origin[:3], time_starts)
    combined_spacing = (*t.spacing[:3], time_spacing)
    # Embed the template's 3x3 spatial direction in a 4x4 matrix; time is an
    # independent axis, so its row and column stay identity.
    combined_direction = np.eye(4)
    combined_direction[:3, :3] = t.direction

    # Use each frame's full array so that no voxels are dropped.
    data = np.stack([frame.numpy() for frame in bold_frames], axis=3)
    combined_bold = ants.from_numpy(
        data,
        origin=combined_origin,
        spacing=combined_spacing,
        direction=combined_direction,
    )
    return combined_bold



def _frame_as_3d(bold_np, frame_idx, spacing):
    """Wrap one frame of a 4D BOLD array as a 3D ANTs image with the given spacing."""
    # Only spacing is set; origin and direction stay at the ANTs defaults.
    # Every frame (including the one used for registration) gets the same
    # metadata, so the single transform computed below applies to all of them.
    # NOTE(review): the BOLD's real origin/direction are not carried over; confirm this is intended.
    return ants.from_numpy(bold_np[:, :, :, frame_idx], spacing=spacing)


def align_stabilized_bold_to_anat(stab_bold_img, t1_img):
    """
    Align an already motion-corrected (stabilized) 4D BOLD image to a T1 anatomy.

    The first BOLD frame is rigidly registered to t1_img, and that one
    transform is applied to every frame. apply_transforms resamples each frame
    into t1_img's space, so the returned 4D image takes its spatial
    origin/spacing/direction from t1_img. Its temporal origin and spacing come
    from stab_bold_img.

    arguments:
    ---------
    stab_bold_img : ANTsImage
        4D motion-corrected BOLD image whose data is indexed as [x, y, z, frame].
    t1_img : ANTsImage
        3D anatomical image used as the fixed (target) image.

    returns:
    -------
    tuple
        (registered_bold_img, data, registered_frames, first_frame), in this order:
        - the 4D registered ANTs image;
        - its 4D numpy data;
        - the list of registered 3D frames;
        - the unregistered first frame used as the moving image.
        The extra values are returned for debugging.
    """
    print("Aligning stabilized bold to anat\nRegistering first frame to anat")

    bold_np = stab_bold_img.numpy()  # convert once, not once per frame
    spatial_spacing = stab_bold_img.spacing[:3]
    n_frames = bold_np.shape[3]

    first_frame = _frame_as_3d(bold_np, 0, spatial_spacing)
    frame_registration = ants.registration(
        fixed=t1_img,
        moving=first_frame,
        type_of_transform="Rigid",
    )

    registered_frames = []
    print("Applying transformations to frames")
    for frame in range(n_frames):
        print(f"     frame{frame}/{n_frames}\n        creating bold frame")
        bold_frame = _frame_as_3d(bold_np, frame, spatial_spacing)
        print("     Applying frame transformation")
        registered_frame = ants.apply_transforms(
            fixed=t1_img,
            moving=bold_frame,
            # Same rigid transform for every frame; per-frame motion is assumed
            # to have been removed already by motion correction.
            transformlist=frame_registration['fwdtransforms'],
            interpolator='linear',
        )
        print("     adding registered frame to list")
        registered_frames.append(registered_frame)

    print("Creating 4D numpy vol from list of 3D ANTs imgs")
    data = np.stack([registered.numpy() for registered in registered_frames], axis=3)

    print("Creating 4D bold image from numpy vol")
    # The registered frames live on t1_img's grid, so the spatial metadata must
    # come from t1_img; only the temporal (4th) components come from the BOLD.
    direction = np.eye(4)
    direction[:3, :3] = t1_img.direction
    registered_bold_img = ants.from_numpy(
        data,
        origin=(*t1_img.origin[:3], stab_bold_img.origin[3]),
        spacing=(*t1_img.spacing[:3], stab_bold_img.spacing[3]),
        direction=direction,
    )
    return registered_bold_img, data, registered_frames, first_frame


#It is in fact garbled at the registered_bold_img level
#debug attempts:
# [x] rigid transform
#   The transform is now offset, and still garbled
# [x] Test temporal mean
#   It looks fine
# [x] frame registration for first frame, not temporal mean
#       This actually works!!!
# [ ] frame registration for each frame

# There's probably a way to DRY up align_stabilized_bold_to_anat and seq_to_4D; they share a lot of logic. Should I DRY them?
# I'm thinking at the very least TODO an "extract 3D frame" function

To test this, let's build a short, truncated copy of the BOLD series (its first three frames).

In [31]:
# First three frames only, carrying over the original image's spatial/temporal metadata.
sliced = bold_image.numpy()[..., :3]
bold_truncated_img = ants.from_numpy(
    sliced,
    origin=bold_image.origin,
    spacing=bold_image.spacing,
    direction=bold_image.direction,
)
#hopefully the above args match all the metadata. I've printed to check but there might be some stuff hidden away

Next, let's motion-correct it.

This returns a dictionary containing the motion-corrected image (under the `'motion_corrected'` key).

In [32]:
# Motion-correct the truncated series. The result is a dict; its
# 'motion_corrected' entry holds the corrected 4D image.
stabilized = ants.motion_correction(bold_truncated_img)

And align it

In [33]:
# Register the motion-corrected series to the T1 anatomy. Besides the 4D
# result, the function also returns these debugging values:
# the stacked numpy data, the list of registered 3D frames, and the
# unregistered first frame.
aligned_bold_seq, numpy_data, frames_list, first_frame = align_stabilized_bold_to_anat(stabilized['motion_corrected'], t1_image)

Aligning stabilized bold to anat
Establishing temporal mean
 Creating frame registration
Applying transformations to frames
     frame0/3
        creating bold frame
     Applying frame transformation
     adding registered frame to list
     frame1/3
        creating bold frame
     Applying frame transformation
     adding registered frame to list
     frame2/3
        creating bold frame
     Applying frame transformation
     adding registered frame to list
Creating 4D numpy vol from list of 3D ANTs imgs
Creating 4D bold image from numpy vol


In [None]:
# Inspect the unregistered first BOLD frame (the moving image used for registration).
explore_3D_vol(first_frame.numpy())

interactive(children=(IntSlider(value=31, description='slice', max=63), Output()), _dom_classes=('widget-inter…

Looks like this is garbled on the numpy level:

In [None]:
# Inspect the first registered frame (output of apply_transforms with the T1 image as the fixed image).
explore_3D_vol(frames_list[0].numpy())

interactive(children=(IntSlider(value=255, description='slice', max=511), Output()), _dom_classes=('widget-int…

In [36]:
# Browse the registered 4D series slice by slice (default x axis) and frame by frame.
explore_fMRI(aligned_bold_seq)

interactive(children=(IntSlider(value=255, description='slice', max=511), IntSlider(value=1, description='fram…

In [None]:
# Overlay the registered BOLD frames on the T1 anatomy, z slice by z slice.
compare_bold_alignment(aligned_bold_seq.numpy(), t1_image.numpy())

There are certainly some motion correction jitters. I have some theories on why:

- I couldn't add the spacing and origin metadata when using `ants.from_numpy()`, as it crashed the program

- I need to explicitly run motion correction (the alignment somehow doesn't account for it)

That being said, I feel like it's good enough for a proof of concept

# Let's run a (placeholder) subtraction method on these

In [None]:
def subtract_BOLD_movement(bold_img):
    """
    Compute frame-to-frame absolute differences of a 4D BOLD image.

    Frame 0 of the result is all zeros. Frame k (k >= 1) is
    |bold[..., k] - bold[..., k - 1]|, computed from the original data.

    arguments:
    ---------
    bold_img : ANTsImage
        4D image whose .numpy() data is indexed as [x, y, z, frame].
        Its origin, spacing and direction are printed for inspection.

    returns:
    -------
    np.ndarray
        Array with the same shape and dtype as bold_img.numpy().
    """
    origin = (bold_img.origin[0], bold_img.origin[1], bold_img.origin[2], bold_img.origin[3])
    spacing = bold_img.spacing
    direction = bold_img.direction
    data = bold_img.numpy()

    print(f"origin {origin}\nspacing {spacing}\ndirection {direction}")
    print(data.shape)

    # zeros_like (not empty_like) so frame 0 really is zero. The differences
    # are taken from the source data, not from a buffer being overwritten.
    movement = np.zeros_like(data)
    # Unsigned-integer subtraction would wrap around, so difference in float.
    if np.issubdtype(data.dtype, np.unsignedinteger):
        source = data.astype(np.float64)
    else:
        source = data
    movement[:, :, :, 1:] = np.abs(np.diff(source, axis=3))
    # Wrapping this as an ANTs image via ants.from_numpy (with the metadata
    # above) reportedly crashed the kernel, so the raw array is returned.
    return movement

In [None]:
#bold_movements_np = subtract_BOLD_movement(aligned_bold_seq) #This crashed the kernel

In [None]:
# NOTE(review): `bold_movements_img` is not assigned anywhere visible in this
# notebook. The previous cell only names `bold_movements_np`, and that line is
# commented out, so this raises NameError on a fresh kernel.
print(bold_movements_img)

#explore_fMRI(bold_movements_img)

In [None]:
# compare_bold_alignment takes the 4D BOLD array first and the 3D anatomy second.
# NOTE(review): `bold_movements_img` must be defined (a 4D image on the T1 grid) before this runs.
compare_bold_alignment(bold_movements_img.numpy(), t1_image.numpy())