# **Within-Group SRM**

- Build functions
	- Preprocess the raw data
	- Build a SRM model
	- Train a within-group SRM model
	- Calculate the ISC and permutation
	- Calculate permutation, find the thresholds of p-value and implement Autoaq command

- Main process
	- Set selected features for each group
	- Get the pre-processed data
	- Train a SRM model and get reconstructed data
	- For each ordered pair of groups among high, mid and low
		- calculate the permutation of two groups
		- Find the thresholds (minimum correlation difference between the two groups) for different p-values
		- Autoaq for each permutation data

In [2]:
import os
import h5py
import brainiak
import brainiak.funcalign.srm
import nibabel as nib
import numpy as np
from scipy import stats
from brainiak import io, isc
from nilearn import masking


In [None]:
# Number of subjects; it is also embedded in the group mask file name.
SUBJ_NUM = 36

# The functional runs and the averaged mask are read from the same directory.
DATA_PATH = '/data/neuro/LLS_audio/derivatives/analysis01/FUNC_reorg/'
MASK_PATH = '/data/neuro/LLS_audio/derivatives/analysis01/FUNC_reorg/'
MASK_NAME = 'EPI_%d_avg_mask.nii.gz' % SUBJ_NUM

# Subject ids per group (10 + 10 + 16 = 36 subjects).
LABELS = dict(
    high=[1, 2, 4, 8, 11, 16, 19, 20, 23, 28],
    mid=[3, 6, 7, 14, 21, 24, 32, 34, 35, 37],
    low=[5, 9, 10, 12, 13, 15, 17, 18, 22, 25, 26, 27, 30, 31, 33, 36],
)
# Integer code stored alongside each subject's data for its group.
LABEL_MAP = dict(high=0, mid=1, low=2)


Preprocess the raw data

In [None]:
# pre-process all data from .nii.gz files


class Processor:
    """Mask every subject's functional runs and cache them in one HDF5 file.

    The cache file (``concatenated_masked.h5``) holds two datasets:
    ``all_runs`` with one (voxels, TRs) array per subject and ``labels``
    with the matching integer group codes from ``LABEL_MAP``.
    """

    def __init__(self):
        self.preprocessed_data_fn = 'subj_run_masked.h5'
        self.concatenated_data_fn = 'concatenated_masked.h5'
        self.save_dir_p = 'data/preprocess'
        self.preprocessed_data_p = os.path.join(self.save_dir_p, self.preprocessed_data_fn)
        self.concatenated_data_p = os.path.join(self.save_dir_p, self.concatenated_data_fn)

    def set_dir_path(self, p):
        """
        Set directory path for saving files, creating it if necessary.

        :param p: str, may be nested; missing parent directories are created
        """
        # makedirs (not mkdir) because callers pass multi-level paths.
        os.makedirs(p, exist_ok=True)

        self.save_dir_p = p
        self.preprocessed_data_p = os.path.join(self.save_dir_p, self.preprocessed_data_fn)
        self.concatenated_data_p = os.path.join(self.save_dir_p, self.concatenated_data_fn)

    def pre_process(self, rts=973):
        """
        Build the concatenated cache file unless it already exists.

        For every subject in ``LABELS`` the four runs are masked, stacked
        along the time axis and transposed to (voxels, TRs).

        :param rts: expected number of TRs per subject after concatenating
            the four runs; used only to validate the resulting shape
        :return: None
        :raises FileNotFoundError: if a run file is missing
        :raises ValueError: if a subject's array is not (mask voxels, rts)
        """
        if os.path.exists(self.concatenated_data_p):
            return

        mask_data = nib.load(os.path.join(MASK_PATH, MASK_NAME))
        mask = np.array(mask_data.dataobj)
        expected_shape = (np.sum(mask > 0), rts)

        all_subj = []
        all_subj_label = []

        for group_name, group in LABELS.items():
            print(group_name, group)
            for subj_i in sorted(group):
                runs = []
                for run_i in [1, 2, 3, 4]:
                    subj_fn = 'sub{:03}.run{:02}.func.resampl.nii.gz'.format(subj_i, run_i)
                    subj_fp = os.path.join(DATA_PATH, subj_fn)

                    if not os.path.exists(subj_fp):
                        raise FileNotFoundError(subj_fp)

                    runs.append(masking.apply_mask(subj_fp, mask_data))  # shape=(tr, v)

                # Stack the runs in time, then transpose to shape=(v, tr).
                one_run = np.concatenate(runs).transpose()

                if one_run.shape != expected_shape:
                    raise ValueError('Shape is invalid: sub{:03} {}'.format(subj_i, one_run.shape))

                print('... succeed sub{:03}'.format(subj_i))
                all_subj.append(one_run)
                all_subj_label.append(LABEL_MAP[group_name])

        # The default save directory may not exist if set_dir_path was never called.
        os.makedirs(self.save_dir_p, exist_ok=True)
        with h5py.File(self.concatenated_data_p, 'w') as fc:
            fc.create_dataset('all_runs', data=np.array(all_subj))
            fc.create_dataset('labels', data=all_subj_label)


# Module-level Processor instance shared by the helper functions in this file.
processor = Processor()


def pre_process(p):
    """
    Run preprocessing with the shared ``processor``, saving into ``p``.

    :param p: str, directory handed to ``processor.set_dir_path``
    """
    print('Preprocessing...')
    shared_processor = processor
    shared_processor.set_dir_path(p)
    shared_processor.pre_process()


Build a SRM model

In [None]:
# build a SRM model


class SRMModel:
    """Wrapper that fits a BrainIAK SRM and reconstructs data from it."""

    def __init__(self):
        self.model = None
        self.n_iter = 20
        self.features = 50

    def optimize_features_num(self):
        """Placeholder: does nothing and returns ``None``."""
        return None

    def train_srm(self, train_data, features=50, n_iter=20):
        """
        Fit a new SRM on ``train_data``.

        Note: ``train_data`` is z-scored (axis 1, ddof=1) and NaN-cleaned
        IN PLACE before fitting, so the caller's array is modified.

        :param train_data: shape = (m, v, tr)
        :param features: number of shared features for the SRM
        :param n_iter: number of SRM iterations
        """
        self.features, self.n_iter = features, n_iter
        self.model = brainiak.funcalign.srm.SRM(features=features, n_iter=n_iter)

        for idx, subject_data in enumerate(train_data):
            train_data[idx] = stats.zscore(subject_data, axis=1, ddof=1)
            train_data[idx] = np.nan_to_num(train_data[idx])

        self.model.fit(train_data)

    def reconstruct(self, data):
        """
        Project ``data`` into the shared space and back into voxel space.

        :param data: shape = (m, v, tr)
        :return: array of shape (m, v, tr) holding, per subject, the fitted
            weight matrix applied to the z-scored shared response
        """
        n_subjects, n_voxels, n_trs = data.shape

        # Shared-space responses from the individual weight matrices, shape = (m, f, tr)
        projected = self.model.transform(data)

        rebuilt = np.zeros((n_subjects, n_voxels, n_trs))
        weights = self.model.w_

        for idx in range(n_subjects):
            # Z-score the shared response before mapping it back.
            projected[idx] = stats.zscore(projected[idx], axis=1, ddof=1)
            projected[idx] = np.nan_to_num(projected[idx])

            # Back-project into this subject's voxel space for ISC.
            rebuilt[idx, :, :] = weights[idx].dot(projected[idx])
            rebuilt[idx] = np.nan_to_num(rebuilt[idx])

        return rebuilt


srm_model = SRMModel()


Train a within-group SRM model

In [None]:
# train a SRM model within each group


def train_srm_model(by, **kwargs):
    """
    Train a SRM model.

    :param by: str, 'group' for within-group SRM. Any other value only
        creates the save directory; no training is dispatched from here.
    :param kwargs: forwarded to ``train_srm_model_by_group``; if it contains
        a non-empty ``save_dir_p`` that directory is created first
    """
    save_dir_p = kwargs.get('save_dir_p')
    if save_dir_p:
        # makedirs handles nested paths and an already-existing directory.
        os.makedirs(save_dir_p, exist_ok=True)

    if by == 'group':
        train_srm_model_by_group(**kwargs)


def train_srm_model_by_group(features, save_dir_p=None):
    """
    Fit one SRM per group and store each group's reconstructed data.

    Results go to ``group_data.h5`` in ``save_dir_p`` with one dataset per
    group ('high', 'mid', 'low'). Nothing is done if that file exists.

    :param features: int, used for every group, or str of '_'-separated
        ints indexed by ``LABEL_MAP`` (e.g. '40_55_30' for high, mid, low)
    :param save_dir_p: output directory; defaults to ``processor.save_dir_p``
    :raises TypeError: if ``features`` is neither int nor str
    """
    if not save_dir_p:
        save_dir_p = processor.save_dir_p

    out_fp = os.path.join(save_dir_p, 'group_data.h5')
    if os.path.exists(out_fp):
        return

    group_names = ['high', 'mid', 'low']

    # Resolve the per-group feature counts before touching any file so a
    # malformed argument fails fast instead of reusing a stale model.
    if isinstance(features, int):
        n_features = {group: features for group in group_names}
    elif isinstance(features, str):
        f_list = features.split('_')
        n_features = {group: int(f_list[LABEL_MAP[group]]) for group in group_names}
    else:
        raise TypeError('features must be int or str, got {}'.format(type(features).__name__))

    print('Training SRM...')
    with h5py.File(processor.concatenated_data_p, 'r') as f_in:
        all_runs = f_in['all_runs'][:]
        all_runs_labels = f_in['labels'][:]

    # Write to a temporary file and rename on success: the final file doubles
    # as the "already done" marker, so it must never be left half-written.
    tmp_fp = out_fp + '.tmp'
    print('Saving reconstructed data...')
    with h5py.File(tmp_fp, 'w') as f_out:
        for group in group_names:
            group_runs = all_runs[all_runs_labels == LABEL_MAP[group]]

            # train_srm z-scores group_runs in place; reconstruct below is
            # deliberately fed that z-scored array.
            srm_model.train_srm(group_runs, features=n_features[group])

            group_runs = srm_model.reconstruct(group_runs)
            f_out.create_dataset(group, data=group_runs)
    os.replace(tmp_fp, out_fp)


Calculate the ISC and permutation

In [None]:
# calculate the ISC and permutation for each group


def calculate_isc(data, **kwargs):
    """
    Compute ISC with ``isc.isc`` and clean the result with ``np.nan_to_num``.

    :param data: passed straight to ``isc.isc``
    :param kwargs: extra keyword arguments for ``isc.isc``
    :return: the ISC values after ``np.nan_to_num``
    """
    return np.nan_to_num(isc.isc(data, **kwargs))


def calculate_permutation_isc(all_subj_corr, labels=None, **kwargs):
    """
    Run ``isc.permutation_isc`` with this project's defaults.

    Note: either one group or two group in your labels

    :param all_subj_corr: ISC values passed as the first argument
    :param labels: group assignment, forwarded as ``group_assignment``
    :param kwargs: optional overrides; ``summary_statistic`` (default
        'mean'), ``n_permutations`` (default 1000) and ``pairwise``
        (default False) are recognised, anything else is forwarded as-is
    :return: tuple (observed, p, distribution)
    """
    # Work on a copy and pop the overridable options: forwarding them both
    # explicitly and through **kwargs raises "multiple values for keyword".
    options = dict(kwargs)
    n_permutations = options.pop('n_permutations', 1000)
    summary_statistic = options.pop('summary_statistic', 'mean')
    pairwise = options.pop('pairwise', False)

    observed, p, distribution = isc.permutation_isc(
        all_subj_corr,
        pairwise=pairwise,
        group_assignment=labels,
        summary_statistic=summary_statistic,
        n_permutations=n_permutations,
        **options
    )
    return observed, p, distribution


def calculate_iscs(group, summary_statistic=None, save_dir_p=None):
    """
    Compute and cache the ISC of one group's reconstructed data.

    Reads dataset ``group`` from ``group_data.h5`` in ``save_dir_p`` and
    stores the ISC under the same name in ``[<summary_statistic>_]reconst_isc.h5``.
    Returns immediately if that dataset is already cached.

    :param group: dataset name, e.g. 'high', 'mid' or 'low'
    :param summary_statistic: forwarded to ``calculate_isc``; also used as
        a file-name prefix when given
    :param save_dir_p: working directory; defaults to 'data/preprocess'
    """
    if not save_dir_p:
        save_dir_p = 'data/preprocess'

    fn = '{}reconst_isc.h5'.format((summary_statistic + '_') if summary_statistic else '')
    fp = os.path.join(save_dir_p, fn)

    if os.path.exists(fp):
        print('Read existed .h5 file', fp)
        # Context manager so the handle is released on the early return too.
        with h5py.File(fp, 'r') as f_c:
            if group in f_c:
                return
    else:
        print('Create new .h5 file at', fp)

    print('Reading reconstructed data...')
    with h5py.File(os.path.join(save_dir_p, 'group_data.h5'), 'r') as f_rg:
        group_data = f_rg[group][:]

    print('Calculating ISC...')
    # Reverse the axes of the stored array before handing it to the ISC routine.
    group_corr = calculate_isc(group_data.transpose(), summary_statistic=summary_statistic)

    print('Saving ISC...')
    # 'a' creates the file when missing and appends to it otherwise.
    with h5py.File(fp, 'a') as f_c:
        f_c.create_dataset(group, data=group_corr)


Calculate permutation, find the thresholds of p-value and implement Autoaq command

In [None]:
def get_two_group_permutation(ftype, save_dir_p, is_save=True):
    """
    Calculate the permutation of two groups.

    The group named by the first letter of ``ftype`` is labelled 0 and the
    group named by the second letter is labelled 1.

    :param ftype: ['hl', 'hm', 'ml', 'lh', 'mh', 'lm']
    :param save_dir_p: directory holding ``group_data.h5`` / ``reconst_isc.h5``
    :param is_save: when True, write the result to
        ``output/<ftype>_permutation.h5``
    :return: tuple (observed, p, distribution)
    """
    # Make sure the ISC of every group is cached before reading it back.
    for group in ['low', 'mid', 'high']:
        print('\nGroup', group)
        calculate_iscs(group, save_dir_p=save_dir_p)

    grp_name = {
        'h': 'high',
        'l': 'low',
        'm': 'mid'
    }
    corr_type_0, corr_type_1 = ftype[0], ftype[1]

    print('Reading reconstructed isc...')
    fp = os.path.join(save_dir_p, 'reconst_isc.h5')
    with h5py.File(fp, 'r') as f_c:
        corr_0 = f_c[grp_name[corr_type_0]][:]
        corr_1 = f_c[grp_name[corr_type_1]][:]

    _corr = np.concatenate((corr_0, corr_1), axis=0)
    _labels = [0] * len(corr_0) + [1] * len(corr_1)
    observed, p, distribution = calculate_permutation_isc(_corr, _labels)

    if is_save:
        # The output directory is not created anywhere else.
        os.makedirs('output', exist_ok=True)
        with h5py.File('output/{}_permutation.h5'.format(ftype), 'w') as f:
            f.create_dataset('observed', data=observed)
            f.create_dataset('p', data=p)
            f.create_dataset('distribution', data=distribution)

    # Returned so the result is usable even when is_save is False.
    return observed, p, distribution


def autoaq_2grp(ftype):
    """
    Autoaq command line.

    Writes the observed statistic of ``output/<ftype>_permutation.h5`` as a
    NIfTI image and runs the ``autoaq`` command on that image.

    :param ftype: ['hl', 'hm', 'ml', 'lh', 'mh', 'lm']
    """

    def get_isc_image(mask_p, corr, img_name=None, is_save=True):
        """Place ``corr`` at the mask's voxels and optionally save the image."""
        print('Writing ISC map to file...')

        # Map the ISC data for the participant into brain space
        brain_template = nib.load(mask_p)
        mask_image = io.load_boolean_mask(mask_p)
        coords = np.where(mask_image)
        isc_vol = np.zeros(brain_template.shape)
        isc_vol[coords] = corr

        # make a nii image of the isc map
        isc_image = nib.Nifti1Image(isc_vol, brain_template.affine, brain_template.header)

        if is_save:
            nib.save(isc_image, 'output/{}_isc.nii.gz'.format(img_name))

    print('Reading permutation result...')
    fp = os.path.join('output', '{}_permutation.h5'.format(ftype))
    # Read what is needed and release the file handle straight away.
    with h5py.File(fp, 'r') as f_c:
        observed = f_c['observed'][:]

    mask_path = os.path.join(MASK_PATH, MASK_NAME)

    print('Saving image...')
    get_isc_image(mask_p=mask_path, corr=observed, img_name='observed_{}'.format(ftype))

    print('Autoaq...')
    observed_fp = os.path.join('output', 'observed_{}_isc.nii.gz'.format(ftype))
    result_fp = os.path.join('output', 'subj36_srm_isc_{}_Talairach.txt'.format(ftype))
    status = os.system('autoaq -i {} -a "Talairach Daemon Labels" '
                       '-t 0.1 -u -p -o {}'.format(observed_fp, result_fp))
    if status != 0:
        # Report instead of raising so the remaining comparisons still run.
        print('autoaq exited with non-zero status:', status)


def analysis_permutation_2grp(ftype):
    """
    Get the threshold of ISC difference.

    For each significance level the smallest absolute observed value among
    entries with p below that level is printed and appended to
    ``output/permutation_threshold.txt``.

    :param ftype: ['hl', 'hm', 'ml', 'lh', 'mh', 'lm']
    """
    print('Reading permutation result...')
    fp = 'output/{}_permutation.h5'.format(ftype)

    with h5py.File(fp, 'r') as f:
        p_value = f['p'][:]
        observed = f['observed'][:]

    def write_log(p, text):
        with open(os.path.join(p, 'permutation_threshold.txt'), 'a') as _f:
            _f.write(text)

    print('Find threshold of ftype:', ftype)
    for sig in [0.05, 0.01, 0.005, .001]:
        significant = np.abs(observed[np.where(p_value < sig)])
        if significant.size == 0:
            # np.min raises on an empty selection; report it and carry on
            # with the remaining significance levels.
            content = '{} has no ISC difference with p<{}'.format(ftype, sig)
        else:
            threshold = np.min(significant)
            content = '{} threshold of ISC difference is {:.2f}, p<{}'.format(ftype, threshold, sig)
        print(content)
        write_log('output', content + '\n')


Main process

In [None]:
# Set selected features for different groups
s_feature = '40_55_30'  # high, mid, low
train_type = 'group'
# Fill the template in the order its labels name: subject count, then features.
# (The arguments used to be swapped, giving 'subj_40_55_30_feature_36'.)
save_path = 'data/preprocess/subj_{}_feature_{}/by_{}'.format(SUBJ_NUM, s_feature, train_type)
# The save path is several levels deep, so create all missing parents up front.
os.makedirs(save_path, exist_ok=True)

# Pre-process the raw data
pre_process(save_path)

# Train a within-group SRM model
train_srm_model(by=train_type, save_dir_p=save_path, features=s_feature)

for ft in ['hl', 'hm', 'ml', 'lh', 'mh', 'lm']:
    # Calculate the permutation of two groups
    get_two_group_permutation(ft, save_path)

    # Find the thresholds (min corr difference of the two groups) for different p-values
    analysis_permutation_2grp(ft)

    # Autoaq bash command
    autoaq_2grp(ft)
