In [None]:
import numpy as np
import pandas as pd
from os import path
from collections import defaultdict
from scipy.io import loadmat
from nilearn.glm.first_level import FirstLevelModel

# imports for data loading, processing and analysis
from utils.io import get_args, get_firstlevel_dir, save_args, load_data
from utils.preprocessing import set_get_timeseries, build_sub_run_df, convert_df_to_desMat, add_VTC_get_breaks, exclude_irrelevant_VTC_TRs, check_dataset2_censored_trs
from utils.nilearn_analysis import get_contrast_info, fit_edge_flm, compute_edge_contrast

In [None]:
# Seed NumPy's global (legacy) random state so repeated runs are reproducible
SEED = 2022
np.random.seed(SEED)

In [None]:
# Build the run configuration, passing '-dataset 2' as the argument list.
# NOTE(review): get_args is a project helper; which other fields it fills with defaults is not visible here.
args = get_args(['-dataset', '2'])

# Set up directories, load data
# Output directory for the first-level results of this configuration; the arguments are
# handed to save_args together with that directory (presumably written to disk — confirm in utils.io).
results_dir = get_firstlevel_dir(args)
save_args(args, results_dir)

# Row/column indices of the strictly lower triangle (k=-1 excludes the diagonal) of an
# nROI x nROI matrix, i.e. one (row, col) pair per unique pair of distinct ROIs.
ridx, cidx = np.tril_indices(args.nROI, -1)
get_timeseries = set_get_timeseries(args, ridx, cidx)  # set the type of timeseries to use, ROI or edge

# contrasts: iterated later to compute one estimate set per contrast.
# contrast_cols: design-matrix column names a run must contain to be kept.
contrasts, contrast_cols = get_contrast_info(args)

# timeseries_data is indexed as [subject index][run index]; sub_file_map[subject index] is the
# list of that subject's per-run files (each is read with loadmat below); sublist holds the
# subject identifiers; ntr_tossed is forwarded to build_sub_run_df.
timeseries_data, sub_file_map, sublist, ntr_tossed = load_data(args)

In [None]:
# PREPROCESS BEHAV DATA, MAKE DESIGN MATRICES
# Both dicts map subject -> list with one entry per retained run; entries are appended
# together, so position i in one list corresponds to position i in the other.
des_mat_dict = defaultdict(list)
fitable_timeseries = defaultdict(list)
for sidx, sub in enumerate(sublist):
    for runidx, sub_file in enumerate(sub_file_map[sidx]):
        run_data = timeseries_data[sidx][runidx]

        # One time stamp per sample of this run: sample index scaled by args.t_r
        n_samples = run_data.shape[0]
        sub_trs = np.arange(n_samples) * args.t_r

        # Behavioural table for this run, built from the run's .mat file
        behav_mat = loadmat(sub_file)
        sub_run_df = build_sub_run_df(behav_mat, ntr_tossed, args)

        # Add the VTC information and retrieve the break timing triple
        sub_run_df, break_info = add_VTC_get_breaks(sub_run_df, args)
        break_onsets, break_offsets, break_durations = break_info

        sub_run_desmat = convert_df_to_desMat(
            sub_run_df,
            sub_trs,
            model=args.model,
            VTC_shift=args.VTC_shift,
            break_durations=break_durations,
        )

        # A run is only usable if its design matrix has a column for every event in the contrast.
        # sidx=22, runidx=2 is the only session that does not contain all events for the contrast in dataset 1 (missing COs)
        # sidx=45, runidx=0 is the only session that does not contain all events for the contrast in dataset 2 (missing OEs)
        has_all_events = all(ev in sub_run_desmat.columns for ev in contrast_cols)
        if not has_all_events:
            print(f'Skipping {sub} run {runidx} because it does not contain all events for the contrast')
            continue

        curr_timeseries = get_timeseries(run_data)
        curr_timeseries, sub_run_desmat, skip_session = check_dataset2_censored_trs(
            sub_run_desmat, curr_timeseries, args
        )
        if skip_session:
            # Dropped silently, exactly as before (no message is printed for this case)
            continue

        sub_run_desmat, curr_timeseries = exclude_irrelevant_VTC_TRs(
            sub_run_desmat, curr_timeseries, args, break_onsets, break_offsets
        )
        des_mat_dict[sub].append(sub_run_desmat)
        fitable_timeseries[sub].append(curr_timeseries)

In [None]:
# FIRST LEVEL MODELS
# For every subject with at least one retained run, fit_edge_flm is called with a fresh
# FirstLevelModel and that subject's runs/design matrices; each contrast then contributes
# one row (the subject's estimates) to that contrast's table.

# (subject, contrast) pairs that must not be estimated.
# subNDARDW205DVZ did not make any omission errors, so 'common_fail' is skipped for them.
# The previous check compared the integer loop index with this ID string, which is never
# equal, so the skip silently never happened.
# NOTE(review): assumes entries of `sublist` are ID strings in exactly this form — confirm.
SKIPPED_SUB_CONTRASTS = {('subNDARDW205DVZ', 'common_fail')}

# contrast -> list of single-subject frames; concatenated once at the end instead of
# growing a DataFrame inside the loop (repeated pd.concat is quadratic).
contrast_rows = defaultdict(list)
for sub, sub_desmats in des_mat_dict.items():
    if not len(sub_desmats):
        continue

    fitted_flm = fit_edge_flm(
        FirstLevelModel(t_r=args.t_r, signal_scaling=args.signal_scaling, noise_model=args.glm_noise_model, n_jobs=6),
        run_Ys=fitable_timeseries[sub],
        design_matrices=sub_desmats,
    )

    for con in contrasts:
        if (sub, con) in SKIPPED_SUB_CONTRASTS:
            continue

        estimates = compute_edge_contrast(fitted_flm, con, stat_type='t', output_type='stat')['estimate']

        if args.use_rois and not args.replicate:
            # Turn per-ROI estimates into per-pair values: outer product of the estimates,
            # restricted to the lower-triangle pairs (ridx, cidx), then a signed square root.
            # Assumes `estimates` is 1-D here — TODO confirm against compute_edge_contrast.
            roi_estimates = (estimates[:, np.newaxis] * estimates[np.newaxis, :])[ridx, cidx]
            estimates = np.sign(roi_estimates) * np.sqrt(np.abs(roi_estimates))

        contrast_rows[con].append(pd.DataFrame(estimates).T)

# Kept as a defaultdict(pd.DataFrame) so that a contrast without any rows still yields an
# (empty) DataFrame when looked up, as it did before.
contrast_dfs = defaultdict(pd.DataFrame)
for con, rows in contrast_rows.items():
    contrast_dfs[con] = pd.concat(rows)

In [None]:
# SAVE FIRSTLEVEL RESULTS
# One CSV per contrast in results_dir; '%s' in the template is filled with the contrast name.
datatype_label = "roi" if args.use_rois else "edge"
results_str = f'model-{args.model}_contrast-%s_datatype-{datatype_label}.csv'
for con in contrasts:
    out_file = path.join(results_dir, results_str % con)
    contrast_dfs[con].to_csv(out_file, index=False)