In [None]:
from datetime import datetime

# Log the wall-clock time at which this cell starts running.
now = datetime.now()
current_time = f"{now:%H:%M:%S}"
print("Current Time =", current_time)

# Put the local animalfmritools checkout near the front of sys.path (index 1)
# so it is importable before later entries.
import sys
sys.path.insert(1, "/opt/animalfmritools")

from niworkflows.engine.workflows import (
    LiterateWorkflow as Workflow,
)
from nipype.interfaces import utility as niu
from nipype.pipeline import engine as pe

from animalfmritools.utils.data_grabber import (
    REVERSE_PE_MAPPING,
    PE_DIR_SCHEMA, 
    PE_DIR_FLIP,
    BidsReader, 
)
from animalfmritools.interfaces.rescale_nifti import RescaleNifti
from animalfmritools.interfaces.flip_nifti import FlipNifti
from animalfmritools.interfaces.evenify_nifti import EvenifyNifti

from animalfmritools.workflows.bold.boldref import init_bold_ref_wf
from animalfmritools.workflows.bold.sdc import init_bold_sdc_wf

from animalfmritools.workflows.registration.utils import init_itk_to_fsl_affine_wf
from animalfmritools.workflows.registration.utils import init_itk_to_fsl_warp_wf
from fmriprep.workflows.bold.registration import init_fsl_bbr_wf

from nipype.interfaces.fsl.maths import Threshold, ApplyMask, MeanImage, UnaryMaths
from nipype.interfaces.fsl import MCFLIRT, FLIRT, ApplyWarp
from nipype.interfaces.fsl.utils import Split
from nipype.interfaces.ants import N4BiasFieldCorrection, RegistrationSynQuick
from niworkflows.interfaces.nibabel import GenerateSamplingReference
from niworkflows.interfaces.confounds import NormalizeMotionParams
from nipype.interfaces.fsl import ConvertXFM, ConvertWarp
from nipype.interfaces.ants import ApplyTransforms

from pathlib import Path

"""
Functions
"""

def load_json_as_dict(json_path):
    """Load a JSON file (e.g. a BIDS sidecar) into a dictionary.

    Parameters
    ----------
    json_path : str or pathlib.Path
        Path to the JSON file. Plain strings are accepted as well as ``Path``.

    Returns
    -------
    dict
        The parsed JSON content.

    Raises
    ------
    FileNotFoundError
        If ``json_path`` does not exist. An explicit exception is used instead
        of ``assert`` so that the check still runs under ``python -O``.
    """
    import json

    json_path = Path(json_path)
    if not json_path.exists():
        raise FileNotFoundError(f"{json_path} not found.")
    # JSON is UTF-8 by specification; do not rely on the platform default encoding.
    with open(json_path, encoding="utf-8") as json_file:
        metadata = json.load(json_file)

    return metadata


def _pick_rel(rms_files):
    return rms_files[-1]

def _jsonify(_dict):
    import json
    import os

    filename = "/tmp/confounds_metadata.json"

    with open(filename, "w") as f:
        json.dump(_dict, f)

    return filename


"""
Start
"""
study_id = 'MouseAD'
# Every study input/output, and the template images, live under the package data directory.
DATA_DIR = Path("/opt/animalfmritools/animalfmritools/data")
outdir = f"{DATA_DIR}/{study_id}"
deriv_dir = Path(outdir) / "derivatives" / "mousefmriprep"

bids_reader = BidsReader(outdir)
# get subject ids (this notebook processes only the first subject/session)
sub_ids = bids_reader.get_subjects()
if len(sub_ids) == 0:
    raise ValueError(f"No subjects found in BIDS directory: {outdir}")
sub_id = sub_ids[0]
# get session ids
ses_ids = bids_reader.get_sessions(sub_id)
if len(ses_ids) == 0:
    raise ValueError(f"No sessions found for subject {sub_id} in: {outdir}")
ses_id = ses_ids[0]
# Factor handed to every RescaleNifti node (anat, template, BOLD, fmap).
# NOTE(review): presumably scales rodent voxel sizes up for tools tuned on human data — confirm.
RESCALE_FACTOR = 10

print(f"""Subject ID: {sub_id}
Session ID: {ses_id}
""")

anat = bids_reader.get_anat(sub_id)
bold_dict = bids_reader.get_bold_runs(sub_id, ses_id, ignore_tasks=[])
fmap_dict = bids_reader.get_fmap_runs(sub_id, ses_id)
template_path = str(DATA_DIR / "TMBTA_Brain_Template.nii.gz")
template_csf_path = str(DATA_DIR / "TMBTA_Ventricles.nii.gz")
template_gm_path = str(DATA_DIR / "TMBTA_Grey.nii.gz")
template_wm_path = str(DATA_DIR / "TMBTA_White_eroF.nii.gz")


wf = Workflow(name="scratch", base_dir=outdir)

"""
Set-up buffer nodes
"""

# Anat: rescale the raw T2w, then N4 bias-field correct the rescaled image.
from nipype.interfaces.ants import N4BiasFieldCorrection

anat_buffer = pe.Node(niu.IdentityInterface(fields=["t2w"]), name="anat_buffer")
anat_buffer.inputs.t2w = anat

rescale_anat = pe.Node(
    RescaleNifti(rescale_factor=RESCALE_FACTOR),
    name="rescale_anat",
)
n4_anat = pe.Node(N4BiasFieldCorrection(), name="n4_anat")
# fmt: off
wf.connect(anat_buffer, "t2w", rescale_anat, "nifti_path")
wf.connect(rescale_anat, "rescaled_path", n4_anat, "input_image")
# fmt: on

# Template: the brain template plus its GM / WM / CSF maps.
template_buffer = pe.Node(
    niu.IdentityInterface(fields=["template", "gm", "wm", "csf"]),
    name="template_buffer",
)
template_buffer.inputs.template = template_path
template_buffer.inputs.gm = template_gm_path
template_buffer.inputs.wm = template_wm_path
template_buffer.inputs.csf = template_csf_path

# Each template image gets the same RESCALE_FACTOR rescaling as the subject data.
rescale_template, rescale_template_gm, rescale_template_wm, rescale_template_csf = (
    pe.Node(RescaleNifti(rescale_factor=RESCALE_FACTOR), name=node_name)
    for node_name in (
        "rescale_template",
        "rescale_template_gm",
        "rescale_template_wm",
        "rescale_template_csf",
    )
)

# fmt: off
wf.connect([
    (template_buffer, rescale_node, [(buffer_field, "nifti_path")])
    for buffer_field, rescale_node in (
        ("template", rescale_template),
        ("gm", rescale_template_gm),
        ("wm", rescale_template_wm),
        ("csf", rescale_template_csf),
    )
])
# fmt: on


# BOLD / FMAP input buffers
def _init_run_buffers(prefix, runs_by_type):
    """Create one IdentityInterface buffer node per phase-encoding direction.

    Every run goes through ``FlipNifti`` (only when ``PE_DIR_FLIP[run_type]``
    is truthy), then ``EvenifyNifti``, then ``RescaleNifti`` with
    ``RESCALE_FACTOR``. The rescaled image is connected to the buffer field
    ``{prefix}_run_{ix:04d}``. Nodes are added to the module-level ``wf``.

    Parameters
    ----------
    prefix : str
        ``"bold"`` or ``"fmap"``; used in node names and buffer-field names.
    runs_by_type : dict
        Maps a PE-direction key (must be in ``PE_DIR_FLIP``) to a list of run paths.

    Returns
    -------
    tuple of (dict, dict)
        ``(buffer_inputs, buffers)``: buffer-field names per run type, and the
        buffer node per run type. Run types with no runs are skipped.

    Raises
    ------
    ValueError
        If a non-empty run type is missing from ``PE_DIR_FLIP``.
    """
    buffer_inputs = {}
    buffers = {}
    for run_type, runs in runs_by_type.items():
        if len(runs) == 0:
            continue
        if run_type not in PE_DIR_FLIP:
            raise ValueError(f"{run_type} not found")

        fields = [f"{prefix}_run_{ix:04d}" for ix in range(len(runs))]
        buffer_inputs[run_type] = fields
        buffers[run_type] = pe.Node(
            niu.IdentityInterface(fields),
            name=f"{prefix}_buffer_{run_type}",
        )

        needs_flip = PE_DIR_FLIP[run_type]
        for ix, (field, run_path) in enumerate(zip(fields, runs)):
            if needs_flip:
                # Flip nifti, then evenify the flipped image
                flip_nifti = pe.Node(
                    FlipNifti(nifti_path=run_path),
                    name=f"{prefix}_flip_{run_type}_{ix}",
                )
                evenify_nifti = pe.Node(
                    EvenifyNifti(),
                    name=f"{prefix}_evenify_{run_type}_{ix}",
                )
                # fmt: off
                wf.connect([
                    (flip_nifti, evenify_nifti, [("flipped_path", "nifti_path")]),
                ])
                # fmt: on
            else:
                evenify_nifti = pe.Node(
                    EvenifyNifti(nifti_path=run_path),
                    name=f"{prefix}_evenify_{run_type}_{ix}",
                )
            # Same rescale factor for every run (the flip and non-flip paths previously disagreed in spelling).
            rescale_nifti = pe.Node(
                RescaleNifti(rescale_factor=RESCALE_FACTOR),
                name=f"{prefix}_rescale_{run_type}_{ix}",
            )
            # fmt: off
            wf.connect([
                (evenify_nifti, rescale_nifti, [("out_path", "nifti_path")]),
                (rescale_nifti, buffers[run_type], [("rescaled_path", field)]),
            ])
            # fmt: on

    return buffer_inputs, buffers


# BOLD
bold_buffer_inputs, bold_buffer = _init_run_buffers("bold", bold_dict)

# FMAP
fmap_buffer_inputs, fmap_buffer = _init_run_buffers("fmap", fmap_dict)

"""
Set-up bold template (one template per PE direction)
- SDC unwarping will be estimated for the bold template
    - Select first bold run, 
    - extract first volume
    - find a reverse-PE EPI volume
        - look for fmap, if nothing found look for reverse PE-bold runs, and extract first volume
    - use the reverse PE pairs to perform topup and obtain the displacement warp for the 1st volume
"""
session_bold_buffer = {}
for run_type, runs in bold_buffer_inputs.items():

    # Per-PE-direction session outputs: SDC warp/affine, SDC-corrected boldref,
    # and the uncorrected (distorted) boldref.
    session_bold_buffer[run_type] = pe.Node(
        niu.IdentityInterface(
            ["sdc_warp", "sdc_affine", "sdc_bold", "distorted_bold"]
        ),
        name=f"session_bold_buffer_{run_type}",
    )

    # Forward-/reverse-PE reference images fed to the SDC workflow; these must
    # be two distinct field names for the connections below to resolve.
    sdc_buffer = pe.Node(
        niu.IdentityInterface(["forward_pe", "reverse_pe"]),
        name=f"sdc_buffer_{run_type}",
    )
    # Get reverse PE direction
    reverse_run_type = REVERSE_PE_MAPPING[run_type]
    # Buffer field of the first run in this PE direction (seeds the session boldref)
    boldref_ses = runs[0]
    forward_pe_metadata = bold_dict[run_type][0]
    # Extract boldref
    session_boldref = init_bold_ref_wf(name=f"session_bold_reference_{run_type}")
    # fmt: off
    wf.connect([
        (bold_buffer[run_type], session_boldref, [(boldref_ses, "inputnode.bold")]),
        (session_boldref, sdc_buffer, [("outputnode.boldref", "forward_pe")]),
        (session_boldref, session_bold_buffer[run_type], [("outputnode.boldref", "distorted_bold")]),
    ])
    # fmt: on

    # Reverse-PE reference: prefer a fieldmap run; otherwise fall back to the
    # first BOLD run acquired in the opposite PE direction. Explicit membership
    # tests are used instead of try/except so unrelated wiring errors surface.
    # (Buffer-input dicts only hold non-empty run lists, so [0] is safe.)
    if reverse_run_type in fmap_buffer_inputs:
        reverse_pe = fmap_buffer_inputs[reverse_run_type][0]
        reverse_pe_metadata = fmap_dict[reverse_run_type][0]
        # fmt: off
        wf.connect([
            (fmap_buffer[reverse_run_type], sdc_buffer, [(reverse_pe, "reverse_pe")])
        ])
        # fmt: on
    elif reverse_run_type in bold_buffer_inputs:
        reverse_pe = bold_buffer_inputs[reverse_run_type][0]
        reverse_pe_metadata = bold_dict[reverse_run_type][0]
        session_reverse_pe_boldref = init_bold_ref_wf(name=f"session_bold_reference_opposite_pe_{run_type}")
        # fmt: off
        wf.connect([
            (bold_buffer[reverse_run_type], session_reverse_pe_boldref, [(reverse_pe, "inputnode.bold")]),
            (session_reverse_pe_boldref, sdc_buffer, [("outputnode.boldref", "reverse_pe")])
        ])
        # fmt: on
    else:
        raise ValueError(f"A reverse PE run could not be found [{reverse_run_type}].")

    session_sdc = init_bold_sdc_wf(
        forward_pe_metadata,
        reverse_pe_metadata,
        name=f"session_bold_sdc_{run_type}"
    )
    # fmt: off
    wf.connect([
        (sdc_buffer, session_sdc, [
            ("forward_pe", "inputnode.forward_pe"),
            ("reverse_pe", "inputnode.reverse_pe"),
        ]),
        (session_sdc, session_bold_buffer[run_type], [
            ("outputnode.sdc_warp", "sdc_warp"),
            ("outputnode.sdc_bold", "sdc_bold"),
            ("outputnode.sdc_affine", "sdc_affine"),
        ])
    ])
    # fmt: on

"""
Register all `sdc_bold`  to the first run_type
"""
if not session_bold_buffer:
    raise ValueError("No BOLD runs were found; a session BOLD template cannot be built.")

# The session template is the SDC-corrected boldref of the first PE direction.
session_template_buffer = pe.Node(
    niu.IdentityInterface(
        ["session_bold_template"]
    ),
    name="session_template_buffer",
)
# One field per PE direction, holding that direction's affine to the session template.
session_template_reg_buffer = pe.Node(
    niu.IdentityInterface([f"aff_bold_to_boldtemplate_{run_type}" for run_type in session_bold_buffer.keys()]),
    name="session_template_reg_buffer"
)
first_key = next(iter(session_bold_buffer))

# fmt: off
wf.connect([
    (session_bold_buffer[first_key], session_template_buffer, [("sdc_bold", "session_bold_template")]),
])
# fmt: on
for session_ix, (run_type, session_node) in enumerate(session_bold_buffer.items()):
    aff_field = f"aff_bold_to_boldtemplate_{run_type}"
    if session_ix == 0:
        # The first PE direction *is* the template, so no extra registration
        # is run; its SDC affine is forwarded directly.
        # fmt: off
        wf.connect([
            (session_node, session_template_reg_buffer, [("sdc_affine", aff_field)])
        ])
        # fmt: on
        continue

    # Other PE directions: 6-dof FLIRT registration (BBR disabled) of their
    # SDC-corrected boldref to the template, then ITK -> FSL affine conversion.
    reg_bold_to_boldtemplate = init_fsl_bbr_wf(
        bold2t1w_dof=6,
        use_bbr=False,
        bold2t1w_init="register",
        omp_nthreads=4,
        name=f"reg_sdc-bold_to_sdc-boldtemplate_{run_type}",
    )
    xfm_convert_itk_to_fsl = init_itk_to_fsl_affine_wf(name=f"itk_to_fsl_bold_to_boldtemplate_{run_type}")
    # fmt: off
    wf.connect([
        (session_template_buffer, reg_bold_to_boldtemplate, [("session_bold_template", "inputnode.t1w_brain")]),
        (session_node, reg_bold_to_boldtemplate, [("sdc_bold", "inputnode.in_file")]),
        (reg_bold_to_boldtemplate, xfm_convert_itk_to_fsl, [("outputnode.itk_bold_to_t1", "inputnode.itk_affine")]),
        (session_node, xfm_convert_itk_to_fsl, [("sdc_bold", "inputnode.source")]),
        (session_template_buffer, xfm_convert_itk_to_fsl, [("session_bold_template", "inputnode.reference")]),
        (xfm_convert_itk_to_fsl, session_template_reg_buffer, [("outputnode.fsl_affine", aff_field)])
    ])
    # fmt: on

"""
Regrid anat and template
"""
# Reference-grid / resampling nodes for the template and the T2w; their
# inputs are connected in the registration sections further down.
regrid_template = pe.Node(GenerateSamplingReference(), name="regrid_template")
regrid_t2w = pe.Node(GenerateSamplingReference(), name="regrid_t2w")
# Template tissue maps are resampled onto the regridded template with an
# identity transform (resampling only, no spatial warp).
regrid_template_gm, regrid_template_wm, regrid_template_csf = (
    pe.Node(ApplyTransforms(transforms="identity"), name=f"regrid_template_{tissue}")
    for tissue in ("gm", "wm", "csf")
)
# Binarized template, used downstream as the brain mask.
binarize_template = pe.Node(UnaryMaths(operation="bin"), name="binarize_template")

from nipype.interfaces.utility import Function
def merge_inputs(input_1, input_2, input_3):
    """Collect three inputs into a single list, preserving argument order."""
    merged = list((input_1, input_2, input_3))
    return merged

# Gathers the regridded GM, WM and CSF maps (connected in that order) into one list.
merge_template_tpms = pe.Node(
    Function(
        function=merge_inputs,
        input_names=["input_1", "input_2", "input_3"],
        output_names=["output_list"],
    ),
    name="merge_template_tpms",
)



"""
Register session bold template to T2w
"""
def _get_split_volume(out_files, vol_id):
    return out_files[vol_id]
    
# Initial BOLD image for BOLD-template -> T2w registration, built from the first
# run of the template PE direction (`first_key`): split into volumes,
# motion-correct to volume 0, average over time, N4-correct, then apply the
# session SDC warp.
initreg_boldtemplate_to_t2w_boldref = pe.Node(Split(dimension="t", out_base_name="split_bold_"), name="initreg_boldtemplate_to_t2w_boldref")
initreg_boldtemplate_to_t2w_hmc = pe.Node(MCFLIRT(save_mats=False, save_plots=False, save_rms=False), name="initreg_boldtemplate_to_t2w_hmc")
initreg_boldtemplate_to_t2w_tmean = pe.Node(MeanImage(), name="initreg_boldtemplate_to_t2w_tmean")
initreg_boldtemplate_to_t2w_n4 = pe.Node(N4BiasFieldCorrection(), name="initreg_boldtemplate_to_t2w_n4")
# NOTE: the variable says "sdc" but the nipype node name ends in "_warp".
initreg_boldtemplate_to_t2w_sdc = pe.Node(ApplyWarp(), name="initreg_boldtemplate_to_t2w_warp")
# 6-dof FLIRT registration (BBR disabled) of the SDC-corrected mean BOLD to the regridded T2w.
reg_boldtemplate_to_t2w = init_fsl_bbr_wf(
    bold2t1w_dof = 6, 
    use_bbr = False,
    bold2t1w_init = "register",
    omp_nthreads=4,
    name = f"reg_sdc-boldtemplate_to_t2w_{first_key}",
)
# Converts the ITK affine produced above into FSL format (used in the per-run transform merge).
xfm_convert_itk_to_fsl_boldtemplate_to_t2w = init_itk_to_fsl_affine_wf(name=f"itk_to_fsl_boldtemplate_to_t2w")

# Buffer field of the first run in the template PE direction.
session_bold_run_input = bold_buffer_inputs[first_key][0]
# fmt: off
wf.connect([
    (bold_buffer[first_key], initreg_boldtemplate_to_t2w_hmc, [(session_bold_run_input, "in_file")]),
    (bold_buffer[first_key], initreg_boldtemplate_to_t2w_boldref, [(session_bold_run_input, "in_file")]),
    # MCFLIRT reference = first split volume.
    (initreg_boldtemplate_to_t2w_boldref, initreg_boldtemplate_to_t2w_hmc, [(("out_files", _get_split_volume, 0), "ref_file")]),
    (initreg_boldtemplate_to_t2w_hmc, initreg_boldtemplate_to_t2w_tmean, [("out_file", "in_file")]),
    (initreg_boldtemplate_to_t2w_tmean, initreg_boldtemplate_to_t2w_n4, [("out_file", "input_image")]),
    # The N4-corrected mean serves as both input and reference grid of ApplyWarp.
    (initreg_boldtemplate_to_t2w_n4, initreg_boldtemplate_to_t2w_sdc, [
        ("output_image", "in_file"),
        ("output_image", "ref_file"),
    ]),
    # NOTE(review): sdc_warp was estimated on the session boldref; this assumes the
    # mean of the motion-corrected run shares that grid/position — TODO confirm.
    (session_bold_buffer[first_key], initreg_boldtemplate_to_t2w_sdc, [("sdc_warp", "field_file")]),
    # T2w sampling reference: N4-corrected T2w as fixed image, session BOLD template as moving image.
    (n4_anat, regrid_t2w, [("output_image", "fixed_image")]),
    (session_template_buffer, regrid_t2w, [("session_bold_template", "moving_image")]),
    (initreg_boldtemplate_to_t2w_sdc, reg_boldtemplate_to_t2w, [("out_file", "inputnode.in_file")]),
    (regrid_t2w, reg_boldtemplate_to_t2w, [("out_file", "inputnode.t1w_brain")]),
    (reg_boldtemplate_to_t2w, xfm_convert_itk_to_fsl_boldtemplate_to_t2w, [("outputnode.itk_bold_to_t1", "inputnode.itk_affine")]),
    (initreg_boldtemplate_to_t2w_sdc, xfm_convert_itk_to_fsl_boldtemplate_to_t2w, [("out_file", "inputnode.source")]),
    (regrid_t2w, xfm_convert_itk_to_fsl_boldtemplate_to_t2w, [("out_file", "inputnode.reference")]),
])
# fmt: on

"""
Register T2w to template
"""
# First mask t2w
# Mask generation: FLIRT (9 dof) registers the regridded template onto the
# regridded T2w; the result is thresholded and applied as a mask to the T2w.
mask_t2w_initreg = pe.Node(
    FLIRT(dof=9),
    name = "mask_t2w_initreg_template_to_t2w",
)
# NOTE(review): 500 is a magic intensity cut-off (in rescaled-template units);
# presumably fslmaths -thr semantics (voxels below are zeroed) — TODO document its origin.
mask_t2w_genmask = pe.Node(
    Threshold(thresh=500),
    name = "mask_t2w_generate_mask"
)
# Presumably fslmaths -mas semantics: T2w voxels are kept where the mask is non-zero — confirm.
mask_t2w = pe.Node(
    ApplyMask(),
    name = "mask_t2w_apply_mask"
)
# register t2w to template
# Initial 12-dof FLIRT affine with a full +/-180 degree rotation search on every axis.
initreg_t2w_to_template = pe.Node(
    FLIRT(
        dof=12,
        searchr_x = [-180, 180],
        searchr_y = [-180, 180],
        searchr_z = [-180, 180],
    ),
    name = "reg_affine_t2w_to_template"
)
# ANTs quick registration of the FLIRT-aligned T2w to the template
# (transform_type 'sr': presumably rigid + deformable SyN per antsRegistrationSyNQuick.sh — confirm).
reg_t2w_to_template = pe.Node(
    RegistrationSynQuick(transform_type = 'sr'),
    name = "reg_t2w_to_template"
)
# ITK -> FSL conversions of the ANTs affine and warp.
xfm_convert_itk_to_fsl_t2w_to_template_affine = init_itk_to_fsl_affine_wf(name=f"itk_to_fsl_t2w_to_template_affine")
xfm_convert_itk_to_fsl_t2w_to_template_warp = init_itk_to_fsl_warp_wf(name=f"itk_to_fsl_t2w_to_template_warp")
# Concatenates the FLIRT initial affine with the converted ANTs affine.
init_xfm_t2w_to_template_affine = pe.Node(ConvertXFM(concat_xfm=True), name="itk_to_fsl_t2w_to_template_mergeaffines")
# Single relative warp: combined affine (premat) plus converted ANTs warp (warp1).
# Also reused by every BOLD run's transform merge.
init_xfm_t2w_to_template_warp = pe.Node(ConvertWarp(relwarp=True), name="itk_to_fsl_t2w_to_template_createwarp")
apply_t2w_to_template = pe.Node(ApplyWarp(), name=f"trans_t2w_to_template")


# fmt: off
wf.connect([
    # Template sampling reference: rescaled template as fixed image, session BOLD template as moving image.
    (rescale_template, regrid_template, [("rescaled_path", "fixed_image")]),
    (session_template_buffer, regrid_template, [("session_bold_template", "moving_image")]),
    (regrid_template, mask_t2w_initreg, [("out_file", "in_file")]),
    (regrid_t2w, mask_t2w_initreg, [("out_file", "reference")]),
    (mask_t2w_initreg, mask_t2w_genmask, [("out_file", "in_file")]),
    (mask_t2w_genmask, mask_t2w, [("out_file", "mask_file")]),
    (regrid_t2w, mask_t2w, [("out_file", "in_file")]),
    # Masked T2w -> template: FLIRT affine, then ANTs.
    (mask_t2w, initreg_t2w_to_template, [("out_file", "in_file")]),
    (regrid_template, initreg_t2w_to_template, [("out_file", "reference")]),
    (initreg_t2w_to_template, reg_t2w_to_template, [("out_file", "moving_image")]),
    (regrid_template, reg_t2w_to_template, [("out_file", "fixed_image")]),
    (reg_t2w_to_template, xfm_convert_itk_to_fsl_t2w_to_template_affine, [("out_matrix", "inputnode.itk_affine")]),
    (initreg_t2w_to_template, xfm_convert_itk_to_fsl_t2w_to_template_affine, [("out_file", "inputnode.source")]),
    (regrid_template, xfm_convert_itk_to_fsl_t2w_to_template_affine, [("out_file", "inputnode.reference")]),
    (reg_t2w_to_template, xfm_convert_itk_to_fsl_t2w_to_template_warp, [("forward_warp_field", "inputnode.itk_warp")]),
    (regrid_template, xfm_convert_itk_to_fsl_t2w_to_template_warp, [("out_file", "inputnode.reference")]),
    # Compose FLIRT affine + ANTs affine + ANTs warp into one FSL warp.
    (initreg_t2w_to_template, init_xfm_t2w_to_template_affine, [("out_matrix_file", "in_file")]),
    (xfm_convert_itk_to_fsl_t2w_to_template_affine, init_xfm_t2w_to_template_affine, [("outputnode.fsl_affine", "in_file2")]),    
    (init_xfm_t2w_to_template_affine, init_xfm_t2w_to_template_warp, [("out_file", "premat")]),
    (xfm_convert_itk_to_fsl_t2w_to_template_warp, init_xfm_t2w_to_template_warp, [("outputnode.fsl_warp", "warp1")]),
    (regrid_template, init_xfm_t2w_to_template_warp, [("out_file", "reference")]),
    # Apply the composed warp to the masked T2w.
    (regrid_template, apply_t2w_to_template, [("out_file", "ref_file")]),
    (mask_t2w, apply_t2w_to_template, [("out_file", "in_file")]),
    (init_xfm_t2w_to_template_warp, apply_t2w_to_template, [("out_file", "field_file")]),
    # Template brain mask and tissue maps on the regridded template grid.
    (regrid_template, binarize_template, [("out_file", "in_file")]),
    (regrid_template, regrid_template_gm, [("out_file", "reference_image")]),
    (rescale_template_gm, regrid_template_gm, [("rescaled_path", "input_image")]),
    (regrid_template, regrid_template_wm, [("out_file", "reference_image")]),
    (rescale_template_wm, regrid_template_wm, [("rescaled_path", "input_image")]),
    (regrid_template, regrid_template_csf, [("out_file", "reference_image")]),
    (rescale_template_csf, regrid_template_csf, [("rescaled_path", "input_image")]),
    # List order passed to the confounds workflow: [GM, WM, CSF].
    (regrid_template_gm, merge_template_tpms, [("output_image", "input_1")]),
    (regrid_template_wm, merge_template_tpms, [("output_image", "input_2")]),
    (regrid_template_csf, merge_template_tpms, [("output_image", "input_3")]),
])
# fmt: on

"""
Process each run
"""
from animalfmritools.interfaces.apply_bold_to_anat import ApplyBoldToAnat
from animalfmritools.workflows.bold.confounds import init_bold_confs_wf
from animalfmritools.workflows.derivatives.outputs import parse_bold_path, get_source_files, init_bold_preproc_derivatives_wf


def _bold_sidecar_path(bold_path):
    """Return the JSON sidecar path of a ``.nii.gz`` or ``.nii`` BOLD image.

    Only the trailing extension is swapped. ``str.replace(".nii.gz", ".json")``
    would also rewrite a matching substring elsewhere in the path, and would
    leave a plain ``.nii`` path unchanged, so the NIfTI itself would then be
    parsed as JSON.

    Raises
    ------
    ValueError
        If ``bold_path`` ends in neither ``.nii.gz`` nor ``.nii``.
    """
    bold_str = str(bold_path)
    for ext in (".nii.gz", ".nii"):
        if bold_str.endswith(ext):
            return Path(bold_str[: -len(ext)] + ".json")
    raise ValueError(f"Expected a .nii or .nii.gz BOLD file, got: {bold_path}")


for run_type, _bold_buffer in bold_buffer.items():

    for bold_ix, bold_input in enumerate(bold_buffer_inputs[run_type]):

        bold_path = bold_dict[run_type][bold_ix]
        metadata = load_json_as_dict(_bold_sidecar_path(bold_path))

        # Per-run reference image, head-motion correction (HMC), and 6-dof
        # registration of this run's boldref to the session's distorted
        # (non-SDC) boldref.
        boldref = init_bold_ref_wf(name = f"{bold_input}_reference_{run_type}")
        hmc = pe.Node(MCFLIRT(save_mats=True, save_plots=True, save_rms=True), name=f"{bold_input}_hmc_{run_type}")
        normalize_motion = pe.Node(NormalizeMotionParams(format="FSL"), name=f"{bold_input}_normalize_motion_{run_type}")
        reg_bold_to_boldtemplate = init_fsl_bbr_wf(
            bold2t1w_dof = 6,
            use_bbr = False,
            bold2t1w_init = "register",
            omp_nthreads=4,
            name = f"reg_{bold_input}_to_nosdc-boldtemplate_{run_type}",
        )
        xfm_convert_itk_to_fsl = init_itk_to_fsl_affine_wf(name=f"itk_to_fsl_{bold_input}_to_nosdc-boldtemplate_{run_type}")

        # fmt: off
        wf.connect([
            (_bold_buffer, boldref,[(bold_input, "inputnode.bold")]),
            (_bold_buffer, hmc,[(bold_input, "in_file")]),
            (boldref, hmc,[("outputnode.boldref", "ref_file")]),
            (hmc, normalize_motion,[("par_file", "in_file")]), # NOTE: par_file are rescaled by RESCALE_FACTOR (10) due to the rescale step on the BOLD input
            (boldref, reg_bold_to_boldtemplate, [("outputnode.boldref", "inputnode.in_file")]),
            (session_bold_buffer[run_type], reg_bold_to_boldtemplate, [("distorted_bold", "inputnode.t1w_brain")]),
            (reg_bold_to_boldtemplate, xfm_convert_itk_to_fsl, [("outputnode.itk_bold_to_t1", "inputnode.itk_affine")]),
            (boldref, xfm_convert_itk_to_fsl, [("outputnode.boldref", "inputnode.source")]),
            (session_bold_buffer[run_type], xfm_convert_itk_to_fsl, [("distorted_bold", "inputnode.reference")]),
        ])
        # fmt: on

        # Merge transforms
        # 1) HMC regs to boldref [run | nosdc] `hmc` [inputs: `mat_file`]
        # 2) reg boldref [run | nosdc] to boldref [session | nosdc] `xfm_convert_itk_to_fsl` [inputs: `fsl_affine`]
        # 3) sdc warp [session] `session_bold_buffer[run_type]` [inputs: `sdc_warp`]
        # 4) reg boldref [session | sdc] to boldref [MAIN - runtype | session | sdc] `session_template_reg_buffer` [inputs: `aff_bold_to_boldtemplate_{run_type}`]
        # 5) reg boldref [MAIN - runtype | session | sdc] to T2w `xfm_convert_itk_to_fsl_boldtemplate_to_t2w` [inputs: `fsl_affine`]
        # 6) reg T2w to template `init_xfm_t2w_to_template_warp` [inputs: `out_file`]
        # merge_4_5: affine 4 then 5; merge_2_5: premat=2, warp1=3, postmat=4-5;
        # merge_2_6: (2-5) followed by warp 6. Step 1 is handed separately to
        # ApplyBoldToAnat as `hmc_mats`.
        merge_4_5 = pe.Node(ConvertXFM(concat_xfm=True), name=f"merge_xfms_4-5_{bold_input}_{run_type}")
        merge_2_5 = pe.Node(ConvertWarp(output_type = "NIFTI_GZ", relwarp=True), name = f"merge_xfms_2-5_{bold_input}_{run_type}")
        merge_2_6 = pe.Node(ConvertWarp(output_type = "NIFTI_GZ", relwarp=True), name = f"merge_xfms_2-6_{bold_input}_{run_type}")

        # Apply: resample the BOLD series onto the regridded template using the
        # HMC matrices and the merged warp.
        apply_bold_to_template = pe.Node(
            ApplyBoldToAnat(debug=False),
            name=f"trans_{bold_input}_to_template_{run_type}",
        )

        # fmt: off
        wf.connect([
            (session_template_reg_buffer, merge_4_5, [(f"aff_bold_to_boldtemplate_{run_type}", "in_file")]),
            (xfm_convert_itk_to_fsl_boldtemplate_to_t2w, merge_4_5, [("outputnode.fsl_affine", "in_file2")]),
            (xfm_convert_itk_to_fsl, merge_2_5, [("outputnode.fsl_affine", "premat")]),
            (session_bold_buffer[run_type], merge_2_5, [("sdc_warp", "warp1")]),
            (merge_4_5, merge_2_5, [("out_file", "postmat")]),
            (regrid_t2w, merge_2_5, [("out_file", "reference")]),
            (merge_2_5, merge_2_6, [("out_file", "warp1")]),
            (init_xfm_t2w_to_template_warp, merge_2_6, [("out_file", "warp2")]),
            (regrid_template, merge_2_6, [("out_file", "reference")]),
            (hmc, apply_bold_to_template, [("mat_file", "hmc_mats")]),
            (_bold_buffer, apply_bold_to_template,[(bold_input, "bold_path")]),
            (merge_2_6, apply_bold_to_template,[("out_file", "bold_to_anat_warp")]),
            (regrid_template, apply_bold_to_template, [("out_file", "anat_resampled")]),
        ])
        # fmt: on

        # Confounds computed on the template-space BOLD, using the binarized
        # template as both BOLD and T1w mask.
        bold_confs_wf = init_bold_confs_wf(
            mem_gb=8,
            metadata=metadata,
            freesurfer=True,
            regressors_all_comps=False,
            regressors_dvars_th=1.5,
            regressors_fd_th=0.5,
            name=f"confounds_{bold_input}_{run_type}_wf",
        )
        bold_confs_wf.inputs.inputnode.skip_vols = 0
        bold_confs_wf.inputs.inputnode.t1_bold_xform = "identity"
        # connect
        # fmt: off
        wf.connect([
            (apply_bold_to_template, bold_confs_wf, [("t1_bold_path", "inputnode.bold")]),
            (binarize_template, bold_confs_wf, [("out_file", "inputnode.bold_mask")]),
            (binarize_template, bold_confs_wf, [("out_file", "inputnode.t1w_mask")]),
            (normalize_motion, bold_confs_wf, [("out_file", "inputnode.movpar_file")]),
            (hmc, bold_confs_wf, [(("rms_files", _pick_rel), "inputnode.rmsd_file")]),
            (merge_template_tpms, bold_confs_wf, [("output_list", "inputnode.t1w_tpms")]),
        ])
        # fmt: on

        # Derivatives for this run, named from its BIDS path.
        deriv_outputs = get_source_files(
            parse_bold_path(bold_path),
            deriv_dir
        )
        bold_deriv_wf = init_bold_preproc_derivatives_wf(
            deriv_outputs,
            name = f"derivatives_{bold_input}_{run_type}_wf",
        )
        # fmt: off
        wf.connect([
            (boldref, bold_deriv_wf, [("outputnode.boldref", "inputnode.bold_ref")]),
            (apply_bold_to_template, bold_deriv_wf, [("t1_bold_path", "inputnode.bold_preproc")]),
            (bold_confs_wf, bold_deriv_wf, [
                (("outputnode.confounds_metadata", _jsonify), "inputnode.bold_confounds_metadata"),
                ("outputnode.confounds_file", "inputnode.bold_confounds"),
                ("outputnode.rois_plot", "inputnode.bold_roi_svg")
            ]),
        ])
        # fmt: on

# Execute the assembled workflow, then log the wall-clock finish time.
wf.run()


now = datetime.now()
current_time = f"{now:%H:%M:%S}"
print("Current Time =", current_time)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import math
# Sampling a 0.2 Hz sinusoid every TR over ~4 minutes: left panel shows the
# samples in time, right panel folds them by the sinusoid period.
f = 0.2  # sinusoid frequency (Hz)
period = 1 / f  # seconds
TR = 1.64  # sampling interval (s)
# Count the whole TRs that fit in 4 minutes directly as an integer.
# Re-deriving it as int(duration / TR) can truncate to one sample fewer when
# the float division lands just below an integer.
n_tps = math.floor(4 * 60 / TR)
duration = n_tps * TR
t = np.linspace(0, duration, n_tps + 1)
y = np.sin(2 * np.pi * f * t)

fig, axs = plt.subplots(ncols=2)
axs[0].scatter(t, y)
axs[0].set(title="Sampled sinusoid", xlabel="Time (s)", ylabel="Amplitude")
axs[1].scatter(t % period, y)
axs[1].set(title="Folded by period", xlabel="Time mod period (s)", ylabel="Amplitude")