# TODO 
- wrap and document plot_later()
- add full range of regressors
- run whole first level analysis
- add proper markdown description

In [17]:
import sys
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

sys.path.append('/home/kmb/Desktop/Neuroscience/Projects/BONNA_decide_net/code')
from dn_utils.behavioral_models import load_behavioral_data                    
from dn_utils.glm_utils import Regressor, my_make_first_level_design_matrix

import nibabel as nib
from bids import BIDSLayout
from nilearn.plotting import plot_stat_map, plot_anat, plot_img, show
from nistats.first_level_model import FirstLevelModel
from nistats.reporting import plot_design_matrix
from nistats.thresholding import map_threshold
from nistats.design_matrix import make_first_level_design_matrix

In [22]:
def plot_later(stat_img=None, bg_img=None, threshold=3, level=.05,
               height_control='fpr'):
    """Display a first-level statistical map in brain space (axial slices).

    Calling ``plot_later()`` without arguments plots the notebook-level
    ``z_map`` on top of ``anat_img`` with a display threshold of 3, exactly
    as before.

    Parameters
    ----------
    stat_img : optional
        Statistical image handed to ``plot_stat_map``. Falls back to the
        notebook-level ``z_map`` when omitted.
    bg_img : optional
        Background image handed to ``plot_stat_map``. Falls back to the
        notebook-level ``anat_img`` when omitted.
    threshold : float or None, optional
        Display threshold. Pass ``None`` to derive it from the map with
        ``map_threshold`` instead of using a fixed value.
    level : float, optional
        ``level`` argument of ``map_threshold``; only used when
        ``threshold`` is ``None``.
    height_control : str, optional
        ``height_control`` argument of ``map_threshold``; only used when
        ``threshold`` is ``None``.
    """
    if stat_img is None:
        stat_img = z_map
    if bg_img is None:
        bg_img = anat_img

    # Previously the statistical threshold was always computed and then
    # thrown away in favour of a constant; compute it only when requested.
    if threshold is None:
        _, threshold = map_threshold(
            stat_img,
            level=level,
            height_control=height_control)

    plot_stat_map(
        stat_img,
        bg_img=bg_img,
        threshold=threshold,
        display_mode='z')

### Load task onsets

Load behavioral data containing relevant task onsets.

In [3]:
# Folder with the behavioral recordings of the main fMRI study
beh_dir = (
    "/home/kmb/Desktop/Neuroscience/Projects/BONNA_decide_net/"
    "data/main_fmri_study/sourcedata/behavioral"
)

# `beh` is a four-dimensional array; `meta` holds the labels of its axes
# (e.g. meta['dim4'] names the entries along the last axis)
beh, meta = load_behavioral_data(root=beh_dir)

# The first three axes are subjects, conditions and trials; the size of the
# last axis is not needed here
n_subjects, n_conditions, n_trials, _ = beh.shape

Shape of beh array: (32, 2, 110, 21)
Conditions [(0, 'rew'), (1, 'pun')]
Columns: [(0, 'block'), (1, 'rwd'), (2, 'magn_left'), (3, 'magn_right'), (4, 'response'), (5, 'rt'), (6, 'won_bool'), (7, 'won_magn'), (8, 'acc_after_trial'), (9, 'onset_iti'), (10, 'onset_iti_plan'), (11, 'onset_iti_glob'), (12, 'onset_dec'), (13, 'onset_dec_plan'), (14, 'onset_dec_glob'), (15, 'onset_isi'), (16, 'onset_isi_plan'), (17, 'onset_isi_glob'), (18, 'onset_out'), (19, 'onset_out_plan'), (20, 'onset_out_glob')]


### Query neuroimaging dataset (path extraction)

Use a `BIDSLayout` object to query the BIDS dataset and pull out the necessary files:
- `anat_files`: sorted list of preprocessed T1w images
- `fmri_files`: list of two lists containing sorted (by subject number) paths to imaging files, first list corresponds to reward condition of PRL task and second list corresponds to punishment condition of PRL task
- `conf_files`: list of two lists containing sorted (by subject number) paths to confound files
- `mask_files`: brain mask files for fMRI sequences

In [6]:
# Root of the BIDS dataset
bids_dir = "/home/kmb/Desktop/Neuroscience/Projects/BONNA_decide_net/data/main_fmri_study"

layout = BIDSLayout(root=bids_dir, derivatives=True, validate=True,
                    index_metadata=False)

# Query templates. Every query asks for plain file names.
anat_filter = dict(
    extension=[".nii.gz"],
    space="MNI152NLin2009cAsym",
    suffix="T1w",
    desc="preproc",
    return_type="filename",
)

fmri_filter = dict(
    extension=[".nii", ".nii.gz"],
    space="MNI152NLin2009cAsym",
    suffix="bold",
    desc="preproc",
    return_type="filename",
)

conf_filter = dict(
    extension="tsv",
    desc="confounds",
    return_type="filename",
)

mask_filter = dict(
    extension=[".nii.gz"],
    space="MNI152NLin2009cAsym",
    desc="brain",
    suffix="mask",
    return_type="filename",
)

# Anatomical images do not depend on the task
anat_files = layout.get(**anat_filter)

# One list per task: index 0 -> "prlrew", index 1 -> "prlpun"
fmri_files, conf_files, mask_files = [], [], []

for task_dict in ({"task": "prlrew"}, {"task": "prlpun"}):
    # Restrict each template to the current task, then run the query
    for query, collected in (
        (fmri_filter, fmri_files),
        (conf_filter, conf_files),
        (mask_filter, mask_files),
    ):
        query.update(task_dict)
        collected.append(layout.get(**query))

# Load model-based regressors (paths are relative to the working directory)
epw_regressors = np.load('data/epw_regressors.npy')
erew_regressors = np.load('data/erew_regressors.npy')
pe_regressors = np.load('data/pe_regressors.npy')

### Single subject analysis

Here, first-level GLM analysis is performed for each subject. For each imaging sequence, the following steps are applied:
1. relevant files are loaded: anatomical, functional, brain mask for functional file
2. `events` DataFrame containing left and right button presses (and misses) events is created. Event onset is assumed to be decision phase onset shifted by response time of participant. Button press events are modeled as zero-duration impulses
3. `confounds` dataframe is loaded and filtered. Included confounds are: six motion parameters and first five a_comp_cor regressors
4. glm model is evaluated (beta estimates for all regressors)
5. contrast of interest is defined. Here, contrast of interest, `left_minus_right`, is the difference between left button press and right button press events.
6. statistical t-map is calculated for defined contrast and saved under BIDS-like name in `bids_dir/derivatives/nistats/buttonpress` directory

In [18]:
# Acquisition timing: number of volumes, repetition time (seconds) and the
# acquisition time of every volume (seconds)
n_scans, t_r = 730, 2
frame_times = t_r * np.arange(n_scans)

# Root folder for storing output nistats derivatives
out_root = (
    "/home/kmb/Desktop/Neuroscience/Projects/BONNA_decide_net/data/"
    "main_fmri_study/derivatives/nistats"
)

# One sub-directory per contrast
# TODO: remove it later
(out_dir_btn, out_dir_pesign, out_dir_pefull,
 out_dir_pesurp, out_dir_epw, out_dir_erew) = [
    os.path.join(out_root, contrast_name)
    for contrast_name in ("buttonpress", "pesign", "pefull",
                          "pesurp", "epw", "erew")
]

# GLM specification shared by all first-level fits
fmri_glm = FirstLevelModel(
    t_r=t_r,
    noise_model='ar1',
    drift_model='cosine',
    period_cut=128,
    standardize=False,
    hrf_model='spm',
    smoothing_fwhm=6,
)

# Confound columns kept in the model: the first five a_comp_cor components
# followed by the six translation / rotation parameters
confounds_relevant = (
    [f'a_comp_cor_0{component}' for component in range(5)]
    + ['trans_x', 'trans_y', 'trans_z', 'rot_x', 'rot_y', 'rot_z']
)

In [None]:
# Positions of the needed variables along the last axis of `beh`; looked up
# once because they do not change between subjects or conditions
col_response = meta['dim4'].index('response')
col_rt = meta['dim4'].index('rt')
col_onset_dec = meta['dim4'].index('onset_dec')
col_onset_out = meta['dim4'].index('onset_out')

for sub_idx in range(n_subjects):

    # The anatomical image is indexed by subject only, so load it once per
    # subject instead of once per condition
    anat_img = nib.load(anat_files[sub_idx])

    for con_idx in range(n_conditions):

        # Load subject data for the current run. The brain mask is taken
        # from the same run as the functional image (it was previously
        # fixed to the first run of the first subject).
        fmri_img = nib.load(fmri_files[con_idx][sub_idx])
        fmri_glm.mask = nib.load(mask_files[con_idx][sub_idx])
        confounds = pd.read_csv(conf_files[con_idx][sub_idx], sep='\t')
        confounds = confounds[confounds_relevant]
        confounds.index = frame_times # Standard time representation (in seconds)

        # Setup events
        trials = beh[sub_idx, con_idx]
        resp_type = trials[:, col_response]
        rt = trials[:, col_rt]
        onset_dec_phase = trials[:, col_onset_dec]
        # Button press time: decision phase onset shifted by response time
        onset_dec = onset_dec_phase + rt
        onset_out = trials[:, col_onset_out]

        # Trials with a response (response code 0 marks a miss)
        answered = resp_type != 0

        pe_regressor = pe_regressors[sub_idx, con_idx, answered]
        epw_regressor = epw_regressors[sub_idx, con_idx, answered]
        erew_regressor = erew_regressors[sub_idx, con_idx, answered]

        #---> choice phase regressors
        # Response codes: -1 -> 'lbp', 1 -> 'rbp', 0 -> 'miss'
        reg_lbp = Regressor('lbp', frame_times, onset_dec[resp_type == -1])
        reg_rbp = Regressor('rbp', frame_times, onset_dec[resp_type == 1])
        reg_miss = Regressor('miss', frame_times, onset_dec[resp_type == 0])
        # Model-based regressors span the decision phase: they start at the
        # decision phase onset and last for the response time
        reg_epw = Regressor(
            'epw',
            frame_times,
            onset_dec_phase[answered],
            duration=rt[answered],
            modulation=epw_regressor
        )
        reg_erew = Regressor(
            'erew',
            frame_times,
            onset_dec_phase[answered],
            duration=rt[answered],
            modulation=erew_regressor
        )

        #---> outcome phase regressors
        # Three variants of the prediction error: raw value, sign only and
        # absolute value
        reg_pe_full = Regressor(
            'pe_full', frame_times, onset_out[answered],
            modulation=pe_regressor
        )
        reg_pe_sign = Regressor(
            'pe_sign', frame_times, onset_out[answered],
            modulation=np.sign(pe_regressor)
        )
        reg_pe_surp = Regressor(
            'pe_surp', frame_times, onset_out[answered],
            modulation=np.abs(pe_regressor)
        )
        # Outcome onsets of missed trials, without modulation
        reg_pe_miss = Regressor(
            'pe_miss', frame_times, onset_out[resp_type == 0]
        )

        ############################################################################
        ############### paste rest of the code from below cells here ###############
        ############################################################################

### Expected probability of winning

In [24]:
# Design matrix: both button-press regressors, the 'epw' regressor and the
# selected confounds
dm, conditions = my_make_first_level_design_matrix(
    [reg_lbp, reg_rbp, reg_epw],
    confounds,
)

# Fit the GLM to the currently loaded functional image
fmri_glm = fmri_glm.fit(fmri_img, design_matrices=dm)

# Contrast of interest: the 'epw' entry of `conditions`
epw_contrast = conditions['epw']

# Compute the statistical map for that contrast
z_map = fmri_glm.compute_contrast(epw_contrast, stat_type='t')

# BIDS-like output name built from the current subject and condition labels
z_map_fname = (
    f"sub-{meta['dim1'][sub_idx]}_"
    f"task-prl{meta['dim2'][con_idx]}_desc-epw_tmap"
)
# Saving is currently disabled:
# nib.save(z_map, os.path.join(out_dir_epw, z_map_fname))

### Button press contrast

In [None]:
# Design matrix: left / right button presses, misses and the selected
# confounds
dm, conditions = my_make_first_level_design_matrix(
    [reg_lbp, reg_rbp, reg_miss],
    confounds,
)

# Fit the GLM to the currently loaded functional image
fmri_glm = fmri_glm.fit(fmri_img, design_matrices=dm)

# Contrast of interest: left button press minus right button press
left_minus_right = conditions['lbp'] - conditions['rbp']

# Compute the statistical map for that contrast
z_map = fmri_glm.compute_contrast(left_minus_right, stat_type='t')

# BIDS-like output name built from the current subject and condition labels
z_map_fname = (
    f"sub-{meta['dim1'][sub_idx]}_"
    f"task-prl{meta['dim2'][con_idx]}_desc-buttonpress_tmap"
)
# Saving is currently disabled:
# nib.save(z_map, os.path.join(out_dir_btn, z_map_fname))

### Full prediction error

In [None]:
# Design matrix: button presses, misses, the full prediction error, the
# outcome of missed trials and the selected confounds
pe_full_regressors = [reg_lbp, reg_rbp, reg_miss, reg_pe_full, reg_pe_miss]
dm, conditions = my_make_first_level_design_matrix(
    pe_full_regressors, confounds)

# Fit the GLM to the currently loaded functional image
fmri_glm = fmri_glm.fit(fmri_img, design_matrices=dm)

# Compute the statistical map for the 'pe_full' entry of `conditions`
pe_full_contrast = conditions['pe_full']
z_map = fmri_glm.compute_contrast(pe_full_contrast, stat_type='t')

# BIDS-like output name built from the current subject and condition labels
z_map_fname = (
    f"sub-{meta['dim1'][sub_idx]}_"
    f"task-prl{meta['dim2'][con_idx]}_desc-pefull_tmap"
)
# Saving is currently disabled:
# nib.save(z_map, os.path.join(out_dir_pefull, z_map_fname))

### Prediction error sign

In [None]:
# Design matrix: button presses, misses, the sign of the prediction error,
# the outcome of missed trials and the selected confounds.
# The regressor is defined in the subject loop as `reg_pe_sign` with the
# name 'pe_sign'; the former spellings `reg_pe_sgn` / 'pe_sgn' do not exist.
dm, conditions = my_make_first_level_design_matrix(
    [reg_lbp, reg_rbp, reg_miss, reg_pe_sign, reg_pe_miss],
    confounds
)

# Fit GLM
fmri_glm = fmri_glm.fit(fmri_img, design_matrices=dm)

# Compute statistical map for the 'pe_sign' entry of `conditions`
z_map = fmri_glm.compute_contrast(
    conditions['pe_sign'],
    stat_type='t')

# BIDS-like output name built from the current subject and condition labels
z_map_fname = f"sub-{meta['dim1'][sub_idx]}_" + \
              f"task-prl{meta['dim2'][con_idx]}_desc-pesign_tmap"
# Saving is currently disabled:
# nib.save(z_map, os.path.join(out_dir_pesign, z_map_fname))

### Prediction error absolute value (surprise)

In [None]:
# Design matrix: button presses, misses, the absolute prediction error
# (surprise), the outcome of missed trials and the selected confounds.
# The regressor is defined in the subject loop as `reg_pe_surp` with the
# name 'pe_surp'; the former spellings `reg_pe_sur` / 'pe_sur' do not exist.
dm, conditions = my_make_first_level_design_matrix(
    [reg_lbp, reg_rbp, reg_miss, reg_pe_surp, reg_pe_miss],
    confounds
)

# Fit GLM
fmri_glm = fmri_glm.fit(fmri_img, design_matrices=dm)

# Define contrast: the 'pe_surp' entry of `conditions`
pe_surp_contrast = conditions['pe_surp']

# Compute statistical map for that contrast
z_map = fmri_glm.compute_contrast(
    pe_surp_contrast,
    stat_type='t')

# BIDS-like output name built from the current subject and condition labels
z_map_fname = f"sub-{meta['dim1'][sub_idx]}_" + \
              f"task-prl{meta['dim2'][con_idx]}_desc-pesurp_tmap"
# Saving is currently disabled:
# nib.save(z_map, os.path.join(out_dir_pesurp, z_map_fname))

## Explore correlation between regressors

In [15]:
# Calculate correlation between expected probability of winning and expected value
reg_corr = np.zeros((n_subjects, n_conditions))

for sub_idx in range(n_subjects):
    for con_idx in range(n_conditions):

        # Load decision phase regressors
        onset_dec = beh[sub_idx, con_idx, :, meta['dim4'].index('onset_dec')] + \
            beh[sub_idx, con_idx, :, meta['dim4'].index('rt')]
        resp_type = beh[sub_idx, con_idx, :, meta['dim4'].index('response')]
        epw_regressor = epw_regressors[sub_idx, con_idx, resp_type != 0]
        erew_regressor = erew_regressors[sub_idx, con_idx, resp_type != 0]
        
        # Grab correlation between decision phase regressors
        reg_epw = Regressor('epw', frame_times, onset_dec[resp_type!=0], 
                            modulation=epw_regressor)
        reg_erew = Regressor('erew', frame_times, onset_dec[resp_type!=0],
                             modulation=erew_regressor)
        reg_corr[sub_idx, con_idx] = Regressor.corrcoef(reg_epw, reg_erew)