# Streamline Aligned Bundle Reliability

In [1]:
import cloudknot as ck

# Point cloudknot at the us-west-2 AWS region for everything that follows.
ck.set_region('us-west-2')

In [2]:
def profile_reliability(ak, sk):        
    # All imports live inside the function on purpose: the function body is
    # what gets shipped to the remote job, so it must be self-contained.
    import os
    import os.path as op
    import pickle

    import nibabel as nib
    import numpy as np
    import pandas as pd
    import s3fs

    from AFQ import api
    import AFQ.data as afd

    from dipy.io.streamline import load_tractogram
    from dipy.stats.analysis import afq_profile, gaussian_weights
    from dipy.tracking.streamline import set_number_of_points, values_from_volume

    from fastdtw import fastdtw
    
    def get_hcp_afq(dataset_name):
        """Create the AFQ object for one dataset directory under `afd.afq_home`.

        Parameters
        ----------
        dataset_name : str
            Name of the dataset directory (e.g. 'HCP_1200') inside
            `afd.afq_home`; it is used as the BIDS path.

        Returns
        -------
        The `api.AFQ` instance built with `dmriprep='dmriprep'`.
        """
        dataset_path = op.join(afd.afq_home, dataset_name)
        return api.AFQ(bids_path=dataset_path, dmriprep='dmriprep')
    
    def get_subject_iloc(afq, subject):
        """Return the row position of `subject` in `afq.data_frame`.

        The result is meant to be used with `afq.data_frame.iloc[...]`, so it
        is the *positional* offset of the first matching row rather than its
        index label (the two only coincide for a default 0..n-1 index).

        Parameters
        ----------
        afq : object exposing a `data_frame` with a 'subject' column
        subject : value compared against the 'subject' column

        Returns
        -------
        int
            Position of the first row whose 'subject' equals `subject`.

        Raises
        ------
        IndexError
            If no row matches `subject`.
        """
        is_subject = (afq.data_frame['subject'] == subject).to_numpy()
        positions = is_subject.nonzero()[0]

        if len(positions) == 0:
            raise IndexError(f'subject {subject!r} not found in afq.data_frame')

        return int(positions[0])

    def get_subject_scalar_data(afq, subject, scalar):
        """Load the `scalar` model image of `subject` and return its data.

        The file name is resolved by `afq._get_fname` from the subject's row
        of `afq.data_frame` with the suffix `_model-<scalar>.nii.gz`; the
        image is read with nibabel and returned via `get_fdata()`.
        """
        subject_row = afq.data_frame.iloc[get_subject_iloc(afq, subject)]
        scalar_path = afq._get_fname(subject_row, f'_model-{scalar}.nii.gz')

        return nib.load(scalar_path).get_fdata()

    def get_subject_bundle_tractogram(afq, subject, bundle_name):
        """Load the cleaned tractogram of one bundle for `subject`.

        Only the base name produced by `afq._get_fname` is kept; the file is
        looked up in the `clean_bundles` folder of the subject's
        `results_dir` and loaded with `load_tractogram(..., 'same')`.
        """
        subject_row = afq.data_frame.iloc[get_subject_iloc(afq, subject)]
        clean_bundles_dir = op.join(subject_row['results_dir'], 'clean_bundles')

        # _get_fname returns a full path; only its file-name part is reused.
        bundle_basename = op.basename(
            afq._get_fname(
                subject_row,
                f'-{bundle_name}_tractography.trk',
                include_track=True,
                include_seg=True
            )
        )

        return load_tractogram(op.join(clean_bundles_dir, bundle_basename), 'same')

    def get_subject_bundle_profile(afq, subject, scalar_data, bundle_name):
        """Compute the weighted tract profile of one bundle for `subject`.

        Parameters
        ----------
        afq : AFQ object the subject belongs to
        subject : subject identifier
        scalar_data : scalar volume sampled along the bundle
        bundle_name : name of the bundle whose tractogram is loaded

        Returns
        -------
        The result of `afq_profile` using Gaussian streamline weights, or an
        all-zero array of length `n_points` when the bundle has no
        streamlines.
        """
        tractogram = get_subject_bundle_tractogram(afq, subject, bundle_name)

        if len(tractogram.streamlines) == 0:
            return np.zeros(n_points)

        # Pass n_points explicitly so the computed profile always has the same
        # length as the empty-bundle fallback above, instead of relying on the
        # library defaults happening to match.
        profile = afq_profile(
            scalar_data,
            tractogram.streamlines,
            tractogram.affine,
            n_points=n_points,
            weights=gaussian_weights(tractogram.streamlines, n_points=n_points)
        )

        return profile

    def get_bundle_profiles(afq):
        """Collect tract profiles as `{subject: {scalar: {bundle_name: profile}}}`.

        Iterates over `hcp_subjects`, `afq.scalars` and `bundle_names`; each
        scalar volume is loaded once per subject and reused for every bundle.
        """
        profiles_by_subject = {}

        for subject in hcp_subjects:
            profiles_by_scalar = {}

            for scalar in afq.scalars:
                scalar_data = get_subject_scalar_data(afq, subject, scalar)
                profiles_by_scalar[scalar] = {
                    bundle_name: get_subject_bundle_profile(afq, subject, scalar_data, bundle_name)
                    for bundle_name in bundle_names
                }

            profiles_by_subject[subject] = profiles_by_scalar

        return profiles_by_subject

    def get_test_retest_correlations():
        """Correlate test and retest tract profiles.

        Returns
        -------
        dict
            `{scalar: {subject: {bundle_name: correlation}}}` where the
            correlation is the test/retest entry of the pandas correlation
            matrix of the two profiles, or 0 when that entry is undefined
            (NaN), e.g. for the all-zero profile of an empty bundle.
        """
        correlations = {}

        test_bundle_profiles = get_bundle_profiles(hcp_test_afq)
        retest_bundle_profiles = get_bundle_profiles(hcp_retest_afq)

        for scalar in hcp_retest_afq.scalars:
            correlations[scalar] = {}
            for subject in hcp_subjects:
                correlations[scalar][subject] = {}
                for bundle_name in bundle_names:
                    test_profile = test_bundle_profiles[subject][scalar][bundle_name]
                    retest_profile = retest_bundle_profiles[subject][scalar][bundle_name]

                    test_retest_corr_matrix = pd.DataFrame(
                        zip(test_profile, retest_profile),
                        columns=['test', 'retest']
                    ).corr()

                    # select only the upper triangle off diagonals of the correlation matrix
                    # (builtin bool: the np.bool alias no longer exists in current NumPy)
                    upper_triangle = np.triu(np.ones(test_retest_corr_matrix.shape), 1).astype(bool)

                    # dropna() makes the "exactly one defined value" check below
                    # independent of whether stack() itself discards NaN entries
                    test_retest_corr = test_retest_corr_matrix.where(upper_triangle).stack().dropna()

                    if len(test_retest_corr) == 1:
                        # positional access; the stacked Series is label-indexed
                        correlations[scalar][subject][bundle_name] = test_retest_corr.iloc[0]
                    else:
                        correlations[scalar][subject][bundle_name] = 0

        return correlations
    
    def get_subject_mean_warped_bundle_profile(afq, subject, scalar_data, bundle_name):
        """Average a bundle's streamline values after warping each onto their mean.

        Every streamline of the bundle is resampled to `n_points` points, the
        scalar volume is sampled along it, and each resulting value sequence is
        aligned to the plain (unwarped) mean sequence with `fastdtw` before the
        aligned sequences are averaged.

        Returns an all-zero array of length `n_points` when the bundle has no
        streamlines.
        """
        tractogram = get_subject_bundle_tractogram(afq, subject, bundle_name)

        if len(tractogram.streamlines) == 0:
            return np.zeros(n_points)

        fgarray = set_number_of_points(tractogram.streamlines, n_points)

        # presumably one row per streamline and one column per point — TODO confirm
        values = np.array(values_from_volume(scalar_data, fgarray, tractogram.affine))
        # unwarped mean across rows; this is the alignment target below
        mean_values = np.mean(values, axis=0)

        dtw_values = []

        for value in values:
            # `path` pairs indices of `value` (column 0) with indices of
            # `mean_values` (column 1); `dist` is not used
            dist, path = fastdtw(value, mean_values)
            path = np.array(path)
            # Take column 0 of every path step whose column 1 differs from the
            # next step's (i.e. the last step of each run of equal column-1
            # values), then append the final index `len(values.T)-1`, and use
            # the resulting indices to pick entries from `value`.
            dtw_values.append(value[np.append(path[np.where(path[:,1][:-1] != path[:,1][1:]),0][0], len(values.T)-1)])

        dtw_values = np.array(dtw_values)

        # mean of the warped sequences across streamlines
        dtw_mean_values = np.mean(dtw_values, axis=0)

        return dtw_mean_values

    def get_mean_warped_bundle_profiles(afq):
        """Collect warped mean profiles as `{subject: {scalar: {bundle_name: profile}}}`.

        Iterates over `hcp_subjects`, `afq.scalars` and `bundle_names`; each
        scalar volume is loaded once per subject and reused for every bundle.
        Each profile comes from `get_subject_mean_warped_bundle_profile`.
        """
        mean_warped_bundle_profiles = {}

        for subject in hcp_subjects:
            mean_warped_bundle_profiles[subject] = {}

            # The subject's row position is resolved inside the helpers that
            # need it, so no separate lookup is done here.
            for scalar in afq.scalars:
                scalar_data = get_subject_scalar_data(afq, subject, scalar)

                mean_warped_bundle_profiles[subject][scalar] = {
                    bundle_name: get_subject_mean_warped_bundle_profile(afq, subject, scalar_data, bundle_name)
                    for bundle_name in bundle_names
                }

        return mean_warped_bundle_profiles

    def get_test_retest_warped_correlations():
        """Correlate test and retest warped mean profiles.

        Returns
        -------
        dict
            `{scalar: {subject: {bundle_name: correlation}}}` where the
            correlation is the test/retest entry of the pandas correlation
            matrix of the two warped profiles, or 0 when that entry is
            undefined (NaN), e.g. for the all-zero profile of an empty bundle.
        """
        correlations = {}

        test_bundle_profiles = get_mean_warped_bundle_profiles(hcp_test_afq)
        retest_bundle_profiles = get_mean_warped_bundle_profiles(hcp_retest_afq)

        for scalar in hcp_retest_afq.scalars:
            correlations[scalar] = {}
            for subject in hcp_subjects:
                correlations[scalar][subject] = {}
                for bundle_name in bundle_names:
                    test_profile = test_bundle_profiles[subject][scalar][bundle_name]
                    retest_profile = retest_bundle_profiles[subject][scalar][bundle_name]

                    test_retest_corr_matrix = pd.DataFrame(
                        zip(test_profile, retest_profile),
                        columns=['test', 'retest']
                    ).corr()

                    # select only the upper triangle off diagonals of the correlation matrix
                    # (builtin bool: the np.bool alias no longer exists in current NumPy)
                    upper_triangle = np.triu(np.ones(test_retest_corr_matrix.shape), 1).astype(bool)

                    # dropna() makes the "exactly one defined value" check below
                    # independent of whether stack() itself discards NaN entries
                    test_retest_corr = test_retest_corr_matrix.where(upper_triangle).stack().dropna()

                    if len(test_retest_corr) == 1:
                        # positional access; the stacked Series is label-indexed
                        correlations[scalar][subject][bundle_name] = test_retest_corr.iloc[0]
                    else:
                        correlations[scalar][subject][bundle_name] = 0

        return correlations
    
    # Authenticated S3 handle used for the AFQ derivative downloads and for
    # uploading the results at the end.
    fs = s3fs.S3FileSystem(anon=False)

    hcp_subjects = [
        '103818', '105923', '111312', '114823', '115320', '122317', '125525', '130518', '135528', '137128',
        '139839', '143325', '144226', '146129', '149337', '149741', '151526', '158035', '169343', '172332',
        '175439', '177746', '185442', '187547', '192439', '194140', '195041', '200109', '200614', '204521',
        '250427', '287248', '341834', '433839', '562345', '599671', '601127', '627549', '660951', '662551',
        '783462', '859671', '861456', '877168', '917255'
    ]

    # Only the first two subjects are processed for now.
    hcp_subjects = hcp_subjects[:2] # ['103818', '105923']

    # Fetch the test and the retest data with the HCP credentials passed in.
    afd.fetch_hcp(hcp_subjects,
                  profile_name=False,
                  aws_access_key_id=ak,
                  aws_secret_access_key=sk)

    afd.fetch_hcp(hcp_subjects,
                  study='HCP_Retest',
                  profile_name=False,
                  aws_access_key_id=ak,
                  aws_secret_access_key=sk)

    # (local dataset directory, S3 folder holding its AFQ derivatives);
    # test first, then retest, for every subject.
    afq_derivative_sources = (
        ('HCP_1200', 'hcp_1200_afq'),
        ('HCP_Retest', 'hcp_retest_afq'),
    )

    for subject in hcp_subjects:
        for dataset_name, s3_folder in afq_derivative_sources:
            subject_dir = op.join(afd.afq_home, dataset_name, 'derivatives', 'afq', f'sub-{subject}')
            # exist_ok so a re-run over an already populated directory does not
            # abort with FileExistsError
            os.makedirs(subject_dir, exist_ok=True)
            fs.get(f'profile-hcp-west/hcp_reliability/single_shell/{s3_folder}/sub-{subject}/', subject_dir, recursive=True)
    
    # debugging why can't find data_description.json
#     for root, dirs, files in os.walk(afd.afq_home):
#         level = root.replace(afd.afq_home, '').count(os.sep)
#         indent = ' ' * 4 * (level)
#         print('{}{}/'.format(indent, os.path.basename(root)))
#         subindent = ' ' * 4 * (level + 1)
#         for f in files:
#             print('{}{}'.format(subindent, f))
    
    # AFQ objects for the two sessions; the nested helpers read these names.
    hcp_test_afq = get_hcp_afq('HCP_1200')
    hcp_retest_afq = get_hcp_afq('HCP_Retest')

    # Bundle names are the keys of the retest AFQ object's bundle dictionary.
    bundle_names = list(hcp_retest_afq.bundle_dict)
    n_points = 100

    # Plain profiles first: compute, then pickle locally.
    plain_correlations = get_test_retest_correlations()

    with open('corr.pkl', 'wb') as corr_file:
        pickle.dump(plain_correlations, corr_file, protocol=pickle.HIGHEST_PROTOCOL)

    # Warped profiles second: compute, then pickle locally.
    warped_correlations = get_test_retest_warped_correlations()

    with open('warped_corr.pkl', 'wb') as warped_corr_file:
        pickle.dump(warped_correlations, warped_corr_file, protocol=pickle.HIGHEST_PROTOCOL)

    # Upload both pickles only after both computations have finished.
    for local_name, remote_path in (
        ('corr.pkl', 'warp-alignment/corr.pkl'),
        ('warped_corr.pkl', 'warp-alignment/warped_corr.pkl'),
    ):
        fs.put(local_name, remote_path)

In [3]:
from datetime import datetime

# Build the cloudknot Knot that runs `profile_reliability`.
# The name gets a second-resolution timestamp suffix. `timespec="seconds"`
# always yields YYYY-MM-DDTHH:MM:SS, whereas slicing seven characters off
# `isoformat()` would cut into the seconds field whenever the microsecond
# component happens to be 0 (isoformat omits it in that case).
# NOTE(review): bid_percentage=105 is above 100 — confirm this is accepted
# (or clamped) as intended.
knot = ck.Knot(
    name="pro_rel-" + datetime.now().isoformat(timespec="seconds").replace(":", "-"),
    pars_policies=("AmazonS3FullAccess",),
    base_image='python:3.8',
    func=profile_reliability,
    image_github_installs="https://github.com/yeatmanlab/pyAFQ.git",
#     image_github_installs="https://github.com/bloomdt-uw/pyAFQ.git@enh-565-subsample-streamlines",
    memory=32000,  # in MiB
    volume_size=50,  # in GiB
    bid_percentage=105
)

Get AWS credentials to access the HCP data.

In [4]:
import configparser
import os.path as op

# Read the 'hcp' profile's key pair from the user's AWS credentials file.
credentials_path = op.join(op.expanduser('~'), '.aws', 'credentials')

cp = configparser.ConfigParser()
# Use a context manager so the file handle is closed once parsing is done.
with open(credentials_path) as credentials_file:
    cp.read_file(credentials_file)

ak = cp.get('hcp', 'AWS_ACCESS_KEY_ID')
sk = cp.get('hcp', 'AWS_SECRET_ACCESS_KEY')

In [5]:
# Submit a single job whose argument tuple is (ak, sk); with starmap=True the
# tuple is presumably unpacked into profile_reliability(ak, sk) — confirm
# against the cloudknot docs.
# NOTE(review): the HCP key pair travels as a job argument here — confirm that
# is acceptable for these credentials.
result_futures = knot.map([(ak, sk)], starmap=True, job_type="independent")

In [6]:
# List the jobs submitted through this knot with their current status
# (sample output shown below).
knot.view_jobs()

Job ID              Name                        Status   
---------------------------------------------------------
059884e5-41e8-4d62-b37d-f24662ebdab5        pro-rel-2020-12-14T22-51-15-0        SUBMITTED


In [7]:
# Tear the knot down once finished; the clobber_* flags additionally request
# removal of its PARS, repo and image (presumably the AWS resources created
# for this knot — confirm against the cloudknot docs).
knot.clobber(clobber_pars=True, clobber_repo=True, clobber_image=True)