# Import modules

In [None]:
from os.path import abspath, join, pardir
from nilearn import datasets
from roi_mean import roi_mean_interface
from bids.layout import BIDSLayout
from nipype.pipeline import Node, MapNode, Workflow
from nipype.interfaces.io import DataSink, DataGrabber
from nipype.algorithms.confounds import TSNR
from nipype.interfaces.utility import Function, IdentityInterface
from nilearn.input_data import NiftiLabelsMasker
from nilearn import plotting
from nipype.interfaces import fsl

# Define variables for data and atlas

In [None]:
# Harvard-Oxford atlas (the 'sub-maxprob-thr25-2mm' variant) whose image and
# labels are used later for the ROI mean calculation.
dataset_sub = datasets.fetch_atlas_harvard_oxford('sub-maxprob-thr25-2mm')
atlas_filename_sub, labels_sub = dataset_sub.maps, dataset_sub.labels

# Root of the BIDS dataset; the relative path is resolved against the
# current working directory.
project_path = abspath("tSNR_ROI/data")
layout = BIDSLayout(project_path)

## Show path to output directory

In [None]:
 abspath(join(project_path, pardir,"output"))

## Show atlas labels

In [None]:
# Display the atlas labels (labels_sub is bound to dataset_sub.labels above).
labels_sub

## Show runs

In [None]:
# Display the run IDs in the BIDS layout; these same values feed the
# "run_id" iterable of the identitysource node.
layout.get_runs()

### Fetch files by subject_id and run_id
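        subject_id and run_id are the plain IDs as returned by `BIDSLayout.get_subjects()` and `BIDSLayout.get_runs()`.
        
    Parameters
    ----------
    layout: BIDSLayout
        Layout to query.
    subject_id: str
        Subject ID without the `sub-` prefix.
    run_id: str
        Run ID without the `run-` prefix.
        
    Returns
    -------
    preprocessed, preprocessed_anat, brainmask, subject_id, run_id
        `brainmask` is a list of all matching mask files; `preprocessed` and `preprocessed_anat` are single file paths.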

In [None]:
def get_files(layout, subject_id, run_id):
    """Fetch the preprocessed files for one subject/run from a BIDS layout.

    This function is wrapped in a nipype ``Function`` node, which executes the
    function's source in isolation, so it must not reference module-level
    names.

    Parameters
    ----------
    layout : BIDSLayout
        Layout to query.
    subject_id : str
        Subject ID without the ``sub-`` prefix.
    run_id : str
        Run ID without the ``run-`` prefix.

    Returns
    -------
    tuple
        ``(preprocessed, preproc_anat, brainmask, subject_id, run_id)``:
        ``preprocessed`` is the first ``preproc`` file for this subject and
        run, ``preproc_anat`` is the first anatomical ``preproc`` file whose
        path contains ``MNI152NLin2009cAsym``, and ``brainmask`` is the *list*
        of all matching ``brainmask`` files.  The IDs are passed through
        unchanged so downstream nodes can use them.

    Raises
    ------
    FileNotFoundError
        If no ``preproc`` file, or no MNI152NLin2009cAsym anatomical
        ``preproc`` file, matches the subject/run.
    """
    brainmask = layout.get(type="brainmask", return_type="file",
                           subject=subject_id, run=run_id)

    # NOTE(review): unlike the anatomical query below, this one is not
    # filtered by space; if several files match, the first one returned by
    # layout.get is used. Confirm it is in the space the FLIRT matrix expects.
    run_preprocs = layout.get(type="preproc", return_type="file",
                              subject=subject_id, run=run_id)
    if not run_preprocs:
        raise FileNotFoundError(
            f"No 'preproc' file found for subject {subject_id!r}, run {run_id!r}")
    preprocessed = run_preprocs[0]

    all_preprocs = layout.get(subject=subject_id, type="preproc",
                              return_type="file", modality="anat")
    mni_anats = [f for f in all_preprocs if "MNI152NLin2009cAsym" in f]
    if not mni_anats:
        raise FileNotFoundError(
            f"No MNI152NLin2009cAsym anatomical 'preproc' file found for "
            f"subject {subject_id!r}")
    preproc_anat = mni_anats[0]

    print(brainmask)
    print(preprocessed)
    print(preproc_anat)

    return preprocessed, preproc_anat, brainmask, subject_id, run_id

# Get Data

### Setup Node identitysource

In [None]:
# Source node that feeds subject/run IDs from the BIDS layout into the
# workflow. Multiple iterable fields are expanded as their cross product.
# NOTE(review): every subject is paired with every run; a subject missing a
# run will make get_files fail for that pair.
subject_ids = layout.get_subjects()
run_ids = layout.get_runs()
identitysource = Node(
    IdentityInterface(fields=["subject_id", "run_id"]),
    name="identitysource",
)
identitysource.iterables = [
    ("subject_id", subject_ids),
    ("run_id", run_ids),
]

### Setup Node BIDSDataGrabber

In [None]:
# Wrap get_files in a Function node; its five return values become the
# node's outputs in the order listed below.
BIDSDataGrabber = Node(
    Function(
        input_names=["layout", "subject_id", "run_id"],
        output_names=["preproc", "preproc_anat", "brainmask",
                      "subject_id", "run_id"],
        function=get_files,
    ),
    name="BIDSDataGrabber",
)
# The layout is fixed for all iterations; subject/run come from identitysource.
BIDSDataGrabber.inputs.layout = layout

# Processing Nodes

### Setup Node FLIRT
#### Register the T1w image from the fMRIPrep output to the MNI152_T1_2mm_brain.nii.gz template

In [None]:
# FLIRT registration of the anatomical image to the FSL MNI152 2 mm brain
# template (mutual-information cost, 640 histogram bins); its matrix is
# reused to register the tSNR map.
flirt = Node(fsl.FLIRT(cost_func="mutualinfo", bins=640), name="flirt")
flirt.inputs.output_type = "NIFTI_GZ"
flirt.inputs.reference = abspath("MNI152_T1_2mm_brain.nii.gz")

### Setup Node tSNR
#### Calculate tSNR Map

In [None]:
# Temporal SNR map of the preprocessed series, removing polynomial trends
# up to order 2 first.
tsnr = Node(interface=TSNR(regress_poly=2), name="tsnr")

### Setup Node Register_tsnr
#### Register the tSNR map to the MNI152_T1_2mm_brain template using the matrix file from FLIRT

In [None]:
# Apply an existing transformation matrix (supplied by the FLIRT node) to
# the tSNR map, resampling it onto the MNI152 2 mm template grid.
register_tsnr = Node(interface=fsl.ApplyXFM(), name="register_tsnr")
register_tsnr.inputs.apply_xfm = True
register_tsnr.inputs.reference = abspath("MNI152_T1_2mm_brain.nii.gz")

### Set Atlas for mean calculation

In [None]:
### Setup Node roi_mean
#### Calculate mean values using the function from roi_mean.py
# The node must exist before its inputs can be set, so it is created first
# and the atlas is attached afterwards in the same cell.
roi_mean = Node(roi_mean_interface, name="roi_mean")

# Atlas image (from the Harvard-Oxford fetch above) used as the ROI file.
roi_mean.inputs.roi_file = dataset_sub.maps
# Specify ROI labels: roi_mean.inputs.roi_label = [1,2,3,4,5]
# if disabled = use all labels

### Setup Node datasink

In [None]:
# Collects the workflow outputs under the "output" folder next to the data
# directory.
datasink = Node(interface=DataSink(), name="datasink")
datasink.inputs.base_directory = abspath(
    join(project_path, pardir, "output")
)

### Connect Nodes, Setup Workflow

In [None]:
# Wire the pipeline: IDs -> file grabber -> tSNR map, registered to MNI with
# the FLIRT matrix -> ROI means -> datasink.
tsnr_wf = Workflow(name="TSNR_calculation")
tsnr_wf.base_dir = abspath('rois_sub')

wf_connections = [
    # subject/run IDs select the files to process
    (identitysource, BIDSDataGrabber, [("subject_id", "subject_id"),
                                       ("run_id", "run_id")]),
    # 'preproc' image -> tSNR map; anatomical image -> FLIRT
    (BIDSDataGrabber, tsnr, [("preproc", "in_file")]),
    (BIDSDataGrabber, flirt, [("preproc_anat", "in_file")]),
    # the FLIRT matrix is applied to the tSNR map
    (flirt, register_tsnr, [("out_matrix_file", "in_matrix_file")]),
    (tsnr, register_tsnr, [("tsnr_file", "in_file")]),
    (register_tsnr, roi_mean, [("out_file", "data_file")]),
    # outputs are grouped in a container named after the subject
    (BIDSDataGrabber, datasink, [("subject_id", "container")]),
    (register_tsnr, datasink, [("out_file", "tsnr")]),
    (roi_mean, datasink, [("out_file", "tsnr.mean")]),
]
tsnr_wf.connect(wf_connections)

tsnr_wf.write_graph(graph2use="colored", format="svg", simple_form=True)

### Run the workflow

In [None]:
# Execute the workflow with nipype's default plugin; the cell displays the
# value returned by run().
tsnr_wf.run()