In [1]:
from nipype.interfaces import freesurfer
from nipype.interfaces import fsl
import nibabel as nib
import os
import numpy as np
from config import root
from tools.helpers import save_nii, mni2index, index2mni
from scipy import ndimage

In [2]:
from tools.roi_selection import Cluster, Sphere
from tools.stats import fdr

In [3]:
# ---------------------------------------------------------------------------
# Path configuration. Everything is resolved relative to the project `root`.
# ---------------------------------------------------------------------------
output = os.path.join(root, "data", "output")
fs_dir = os.path.join(root, "data", "in_analysis", "fs_subjects")
standard = os.path.join(root, "data", "in_analysis", "nii", "standard")
mni_path = os.path.join(standard, "MNI152_T1_2mm.nii.gz")
# `img.get_data()` is deprecated in nibabel; reading through `dataobj`
# yields the same array and works on both old and new releases.
mni = np.asanyarray(nib.load(mni_path).dataobj)
mask_path = os.path.join(standard, "segmentation_seg.nii.gz")

# Level 3 contrasts
l3_loc = os.path.join(output, "loc_mixed_fx")
l3_out = os.path.join(output, "pop_decay_mixed_fx")
l2_out = os.path.join(output, "pop_decay_L2_fixedfx_warped")


def _zstat_path(level3_dir, contrast_idx):
    """Return the path of `zstat1.nii.gz` for one numbered contrast folder.

    Parameters
    ----------
    level3_dir : str
        Directory that contains one numbered sub-folder per contrast.
    contrast_idx : int
        Number of the contrast sub-folder.

    Returns
    -------
    str
        ``<level3_dir>/<contrast_idx>/_fixedflameo0/zstat1.nii.gz``
    """
    return os.path.join(
        level3_dir, str(contrast_idx), "_fixedflameo0", "zstat1.nii.gz")


# Contrast name -> z-statistic image. The numbers are the folder names on
# disk; the names are the labels used throughout this notebook.
contrasts = {
    "right_upper": _zstat_path(l3_out, 0),
    "left_upper": _zstat_path(l3_out, 1),
    "left_lower": _zstat_path(l3_out, 2),
    "right_lower": _zstat_path(l3_out, 3),
    "loc_priming": _zstat_path(l3_out, 4),
    "loc_lag1_priming": _zstat_path(l3_out, 5),
    "loc_lag2_priming": _zstat_path(l3_out, 6),
    "clr_priming": _zstat_path(l3_out, 7),
    "clr_lag1_priming": _zstat_path(l3_out, 8),
    "clr_lag2_priming": _zstat_path(l3_out, 9),
    "baseline": _zstat_path(l3_out, 10),
    "loc_0": _zstat_path(l3_loc, 0),
    "loc_1": _zstat_path(l3_loc, 1),
    "loc_2": _zstat_path(l3_loc, 2),
    "loc_3": _zstat_path(l3_loc, 3),
}

# Per-subject path templates; fill the "%03d" placeholders with the subject
# number using the % operator.
registration = os.path.join(root, "data/output/register_to_standard/sub_%03d")
inplane = os.path.join(root, "data/in_analysis/nii/sub_%03d/ses_000/anatomy/inplane.nii.gz")
# Two placeholders here: one inherited from `registration`, one in the file
# name, so this template must be formatted with a (sub, sub) tuple.
premat = os.path.join(registration, "inplane_brain_bbreg_sub_%03d.mat")
clr_roi_sub_path = os.path.join(output, "postprocessing", "rois", "sub_%03d", "clr")
loc_roi_sub_path = os.path.join(output, "postprocessing", "rois", "sub_%03d", "loc")

In [4]:
# Coordinate table for the "clr" (color) condition: ROI label -> (x, y, z).
# Each tuple is passed to mni2index() when the ROIs are built, and the label
# becomes part of the ROI file name ("<label>_peak.nii.gz").
# Lines that are commented out are entries that are currently disabled.
table3 = {
    "l-ips"    : (-26, -62, 48),
    "l-fef"    : (-34, 6,   52),
    "l-lat-occ" : (-36, -72, -6),
    "l-fg"     : (-44, -56, -16),
    "r-ips"    : (40, -48, 58),
    "r-acc"    : (-2, -14, 52),  # NOTE(review): x is negative although the label starts with "r-" — confirm
#     "r-mfg"    : (32, 40, 26),
    "r-occ"    : (12, -88, -8),
    "r-fef"    : (32, -2, 50),
}

# Coordinate table for the "loc" (location) condition, same layout as above:
# ROI label -> (x, y, z), with disabled entries left commented out.
table2 = {
    "lips"    : (-30, -60, 40),
    "lfef"    : (-32, -12, 54),
#     "lmfg"    : (-34, 36,  18),
    "lpc"     : (-8, -70,  8),
    "rfef"    : (28, -8, 56),
#     "rmfg"    : (28, 26, 22),
#     "rifg"    : (44, -18, -4),
#     "racc"    : (2, 22, 36),
    "rap"     : (34, -34, 60),
    "rip"     : (48, -42, 40),
    "rpc"     : (14, -68, 20),
    "rips"    : (24, -66, 48),
}

In [5]:
# Boolean mask of the voxels labelled 2 in the segmentation volume
# (presumably the cortex label, judging by the variable name — TODO confirm).
# `get_data()` is deprecated in nibabel; `dataobj` gives the same array.
mask_ctx = np.asanyarray(nib.load(mask_path).dataobj) == 2

# Convolve the mask with a 2x2x2 box of ones. `ndimage.convolve` is the same
# function as the deprecated `ndimage.filters.convolve` alias.
base_mask = ndimage.convolve(mask_ctx, np.ones((2, 2, 2)))

# Baseline z-map, loaded once and reused for both the FDR call and the
# thresholding below.
zstat = np.asanyarray(nib.load(contrasts["baseline"]).dataobj)

# Project FDR helper called with q = .1 and no mask; only `zthr` is used here.
thr, zthr, pvals, thrline, pcor, padj = fdr(zstat, q=.1, mask=None)

# Keep only the mask voxels whose baseline z value exceeds the FDR threshold.
base_mask *= zstat > zthr

In [6]:
# All-zero integer volume with the shape of `zstat`; copied whenever a fresh
# mask or ROI volume is needed.
zero_mask = np.zeros(shape=zstat.shape, dtype=int)

# Sphere of radius 8 (project `Sphere` helper) used as the search region.
sph = Sphere(radius=8)

# Path of the 2 mm MNI152 T1 template inside the `standard` folder.
mni_path = os.path.join(standard, "MNI152_T1_2mm.nii.gz")

def slice_sphere(x, y, z, r):
    """Return the index of the cube of half-width `r` centred on (x, y, z).

    Parameters
    ----------
    x, y, z : int
        Centre of the cube, in array indices.
    r : int
        Half-width of the cube; each axis spans ``2 * r + 1`` elements.

    Returns
    -------
    tuple of slice
        Three slices, ``slice(c - r, c + r + 1)`` for each centre coordinate.
        A tuple (not a list) is returned because NumPy only accepts a tuple
        as a multi-dimensional index.

    Notes
    -----
    No bounds checking is done: if ``c - r`` is negative the slice start is
    interpreted relative to the end of the axis, so a centre closer than `r`
    to the lower edge of the volume does not select the intended cube.
    """
    return (
        slice(x - r, x + r + 1),
        slice(y - r, y + r + 1),
        slice(z - r, z + r + 1),
    )

In [7]:
# Load the "clr_lag2_priming" z-map. From here on `zstat` is the nibabel
# image object (its affine is needed when the ROIs are saved) and
# `zstat_data` is the voxel array.
zstat = nib.load(contrasts["clr_lag2_priming"])
# `get_data()` is deprecated in nibabel; `dataobj` gives the same array.
zstat_data = np.asanyarray(zstat.dataobj)

In [8]:
# Paths of the ROI files, grouped by condition; each condition starts with
# its own empty list.
rois = {cond_key: [] for cond_key in ("clr", "loc")}

In [9]:
# Small sphere stamped around every peak. It is the same for all ROIs, so it
# is built once instead of once per table entry.
small_sphere = Sphere(radius=2)

for cond, table in zip(["loc", "clr"], [table2, table3]):

    # One output folder per condition. makedirs (unlike mkdir) also creates
    # the parent folders when they do not exist yet.
    out_dir = os.path.join(
        root, "data", "output", "postprocessing", "rois", cond)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    for name, coords in table.items():

        # To integers
        x, y, z = mni2index(coords)

        # Search region: the large sphere centred on the tabulated coordinate.
        # The slices are passed as a tuple because NumPy does not accept a
        # list of slices as a multi-dimensional index.
        mask = zero_mask.copy()
        mask[tuple(slice_sphere(x, y, z, sph.radius))] = sph.box.astype(int)
        # Optional restriction to the baseline mask (currently disabled):
        #     mask *= base_mask
        in_sphere = mask.astype(bool)

        # The "peak" is the most negative value inside the search region.
        min_val = zstat_data[in_sphere].min()

        # Locate that value *inside the search region only*. Comparing the
        # whole volume against min_val could pick an unrelated voxel that
        # happens to hold the same value elsewhere in the image.
        peakx, peaky, peakz = np.argwhere(in_sphere & (zstat_data == min_val))[0]

        # ROI = small sphere centred on the peak voxel.
        roi = zero_mask.copy()
        roi[tuple(slice_sphere(peakx, peaky, peakz, small_sphere.radius))] = \
            small_sphere.box.astype(int)

        out_file = "%s_peak.nii.gz" % name
        out_path = os.path.join(out_dir, out_file)

        # Guarded so that re-running the cell does not register a path twice.
        if out_path not in rois[cond]:
            rois[cond].append(out_path)

        # `roi` is int64 on most platforms and recent nibabel refuses int64
        # image data unless a dtype is given, so store it as int32
        # (ROI values are presumably 0/1 — TODO confirm).
        img = nib.Nifti1Image(roi.astype(np.int32), zstat.affine)
        nib.save(img, out_path)

In [10]:
# Folder that holds the group-level ROIs (one sub-folder per condition) and
# the per-subject "sub_XXX" folders. Constant, so computed once.
rois_dir = os.path.join(root, "data", "output", "postprocessing", "rois")

# Subjects 1..6
for sub in range(1, 7):
    clr_path = clr_roi_sub_path % sub
    loc_path = loc_roi_sub_path % sub

    # Make sure both per-subject output folders exist.
    if not os.path.exists(loc_path):
        os.makedirs(loc_path)

    if not os.path.exists(clr_path):
        os.makedirs(clr_path)

    # Invert the subject's warp field with FSL invwarp, using the MNI template
    # as reference. Skipped when the inverse field is already on disk.
    inv_path = os.path.join(registration, "inv_warp.nii.gz") % sub
    if not os.path.exists(inv_path):
        inv_warp = fsl.InvWarp(
            warp=os.path.join(registration, "orig_field.nii.gz") % sub,
            inverse_warp=inv_path,
            reference=mni_path,
            output_type='NIFTI_GZ'
        )
        inv_warp.run()

    # postmat is the matrix inverse of the subject's premat file; it is
    # rewritten on every run.
    postmat = os.path.join(registration, "postmat.mat") % sub
    np.savetxt(
        postmat,
        np.linalg.inv(np.loadtxt(premat % (sub, sub))),
        fmt="%.8f"
    )

    for cond in ["clr", "loc"]:

        # Only NIfTI files are warped, in a fixed (sorted) order; any other
        # entry in the folder (hidden files, notes, ...) is ignored instead of
        # being handed to applywarp.
        rois_files = sorted(
            entry for entry in os.listdir(os.path.join(rois_dir, cond))
            if entry.endswith((".nii", ".nii.gz"))
        )
        for roi in rois_files:
            roi_path = os.path.join(rois_dir, cond, roi)

            # Format only the "sub_%03d" component so that a "%" in an ROI
            # file name cannot break the string formatting.
            sub_roi_path = os.path.join(rois_dir, "sub_%03d" % sub, cond, roi)

            # Apply the inverse warp followed by postmat, resampling onto the
            # subject's inplane image with nearest-neighbour interpolation.
            warp = fsl.ApplyWarp(field_file=inv_path,
                                 ref_file=inplane % sub,
                                 in_file=roi_path,
                                 out_file=sub_roi_path,
                                 postmat=postmat,
                                 interp='nn')

            warp.run()