In [1]:
import time

In [2]:
# Wall-clock start time (seconds since the epoch), presumably used later in
# the notebook to report total run time.
tic = time.time()

In [3]:
import configparser
import os.path as op
import tempfile
import warnings
import numpy as np
import nibabel as nib
import dipy.data as dpd
import dipy.tracking.utils as dtu
import dipy.tracking.streamline as dts
from dipy.io.streamline import save_tractogram, load_tractogram
from dipy.stats.analysis import afq_profile, gaussian_weights
from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.io.stateful_tractogram import Space
import dipy.core.gradients as dpg
from dipy.reconst import dti
from dipy.reconst import csdeconv as csd

import AFQ.data as afd
import AFQ.tractography as aft
import AFQ.registration as reg
import AFQ.segmentation as seg
import AFQ.api as api


import s3fs

import logging
# Emit INFO-level log records (e.g. the progress messages logged inside
# segment_bundle) to the notebook output.
logging.basicConfig(level=logging.INFO)



In [4]:
# Read the keys for the HCP open-access bucket from the ``[hcp]`` profile of
# the local AWS credentials file.
CP = configparser.ConfigParser()
# Context manager so the credentials file handle is closed promptly instead
# of being leaked by a bare ``open`` call.
with open(op.join(op.expanduser('~'), '.aws', 'credentials')) as credentials_file:
    CP.read_file(credentials_file)
hcp_ak = CP.get('hcp', 'AWS_ACCESS_KEY_ID')
hcp_sk = CP.get('hcp', 'AWS_SECRET_ACCESS_KEY')

In [5]:
# S3 filesystem handle for the HCP open-access data, authenticated with the
# keys read from the ``[hcp]`` credentials profile.
fs = s3fs.S3FileSystem(
    key=hcp_ak,
    secret=hcp_sk,
)

In [6]:
# Read the keys for the experiment bucket from the ``[default]`` profile of
# the local AWS credentials file.
CP = configparser.ConfigParser()
# Context manager so the credentials file handle is closed promptly instead
# of being leaked by a bare ``open`` call.
with open(op.join(op.expanduser('~'), '.aws', 'credentials')) as credentials_file:
    CP.read_file(credentials_file)
ak = CP.get('default', 'AWS_ACCESS_KEY_ID')
sk = CP.get('default', 'AWS_SECRET_ACCESS_KEY')

In [7]:
# S3 filesystem handle for the experiment bucket (mapping, streamline chunks
# and segmentation results), authenticated with the ``[default]`` profile.
my_fs = s3fs.S3FileSystem(
    key=ak,
    secret=sk,
)

In [8]:
# HCP subject ID. Used to build this subject's S3 input paths
# (``hcp-openaccess/HCP_1200/{subject}/...``) and experiment paths
# (``hcp.pangeo.experiments/{subject}/...``).
subject = 100307

In [9]:
# dwi_fname = f'hcp-openaccess/HCP_1200/{subject}/T1w/Diffusion/data.nii.gz'
# dwi_img = afd.s3fs_nifti_read(dwi_fname, fs=fs)

In [10]:
# with fs.open(f'hcp-openaccess/HCP_1200/{subject}/T1w/Diffusion/bvals') as ff:
#     bvals = np.loadtxt(ff)
    
# with fs.open(f'hcp-openaccess/HCP_1200/{subject}/T1w/Diffusion/bvecs') as ff:
#     bvecs = np.loadtxt(ff)

In [11]:
# gtab = dpg.gradient_table(bvals, bvecs, b0_threshold=50)
# mapping = reg.syn_register_dwi(dwi_img, gtab)[1]

In [12]:
# reg.write_mapping(mapping, 'mapping.nii.gz')

In [13]:
# mapping_img = nib.load('mapping.nii.gz')

In [14]:
# afd.s3fs_nifti_write(mapping_img, 
#                      f'hcp.pangeo.experiments/{subject}/mapping.nii.gz', 
#                      fs=my_fs)

In [15]:
def segment_bundle(params):
    """Segment one bundle out of one chunk of streamlines and upload the result.

    Fetches the subject's b-values/b-vectors, DWI data, precomputed mapping
    and the ``sl_idx``-th streamline chunk from S3, runs pyAFQ segmentation
    for a single bundle, and uploads one ``.trk`` file per resulting fiber
    group back to the experiment bucket.

    Relies on the module-level ``subject``, ``fs`` (HCP open-access bucket)
    and ``my_fs`` (experiment bucket).

    Parameters
    ----------
    params : tuple of (str, int)
        ``(bundle_name, sl_idx)``. ``bundle_name`` is passed to
        ``api.make_bundle_dict``; ``sl_idx`` selects the streamline chunk
        ``sl-{sl_idx:03d}.trk``. Packed into one tuple, presumably so the
        function can be mapped over a list of work items.
    """
    log = logging.getLogger(__name__)
    bundle_name, sl_idx = params
    hcp_prefix = f'hcp-openaccess/HCP_1200/{subject}/T1w/Diffusion'
    exp_prefix = f'hcp.pangeo.experiments/{subject}'

    # Work in a private scratch directory. Fixed names such as 'data.nii.gz'
    # and 'tmp.trk' in the shared working directory would be clobbered by
    # concurrent calls (e.g. several tasks on one Dask worker), and the large
    # DWI copy would otherwise be left on disk after every call.
    with tempfile.TemporaryDirectory() as workdir:
        bval_path = op.join(workdir, 'bvals')
        bvec_path = op.join(workdir, 'bvecs')
        dwi_path = op.join(workdir, 'data.nii.gz')
        trk_path = op.join(workdir, 'tmp.trk')

        # Segmentation.segment is given file names below, so the gradient
        # table is round-tripped through local text files.
        with fs.open(f'{hcp_prefix}/bvals') as ff:
            np.savetxt(bval_path, np.loadtxt(ff))
        with fs.open(f'{hcp_prefix}/bvecs') as ff:
            np.savetxt(bvec_path, np.loadtxt(ff))

        log.info("Getting DWI data")
        dwi_img = afd.s3fs_nifti_read(f'{hcp_prefix}/data.nii.gz', fs=fs)
        log.info("Saving DWI data")
        nib.save(dwi_img, dwi_path)

        log.info("Getting mapping")
        mapping_img = afd.s3fs_nifti_read(f'{exp_prefix}/mapping.nii.gz',
                                          fs=my_fs)
        reg_template = dpd.read_mni_template()
        mapping = reg.read_mapping(mapping_img, dwi_img, reg_template)

        log.info("Getting bundle dict")
        bundle_dict = api.make_bundle_dict(bundle_names=[bundle_name])
        segmentation = seg.Segmentation()

        log.info("Getting streamlines")
        my_fs.download(f'{exp_prefix}/sl-{sl_idx:03d}.trk', trk_path)
        tg = load_tractogram(trk_path, dwi_img)
        # Apply the inverse DWI affine to bring the streamlines into the
        # DWI voxel grid before segmenting.
        streamlines = dts.Streamlines(
            dtu.transform_tracking_output(tg.streamlines,
                                          np.linalg.inv(dwi_img.affine)))

        log.info("Segmenting")
        fiber_groups = segmentation.segment(bundle_dict,
                                            streamlines,
                                            fdata=dwi_path,
                                            fbval=bval_path,
                                            fbvec=bvec_path,
                                            mapping=mapping,
                                            b0_threshold=50)

        log.info("Saving out results and uploading")
        for bundle, fibers in fiber_groups.items():
            log.info("%s: %d streamlines", bundle, len(fibers))

            # Map voxel coordinates back through the affine and tag the
            # tractogram as RASMM.
            sft = StatefulTractogram(
                dtu.transform_tracking_output(fibers, dwi_img.affine),
                dwi_img,
                Space.RASMM)

            # NOTE(review): output names use the unpadded sl_idx while the
            # input chunk name is zero-padded to 3 digits; kept unchanged so
            # the uploaded S3 keys stay the same.
            out_name = f'{bundle}-{sl_idx}_afq.trk'
            out_path = op.join(workdir, out_name)
            save_tractogram(sft, out_path, bbox_valid_check=False)
            my_fs.upload(out_path, f'{exp_prefix}/{out_name}')

In [None]:
# Run one work item locally (CST bundle, streamline chunk 0), presumably as a
# smoke test before distributing the work with Dask in the cells below.
segment_bundle(('CST', 0))

INFO:__main__:Getting DWI data


In [None]:
from dask.distributed import Client
from dask_kubernetes import KubeCluster

In [None]:
# Number of Dask workers requested from the Kubernetes cluster below.
n_workers = 40

In [None]:
# Start a Dask cluster on Kubernetes with ``n_workers`` workers; the bare
# ``cluster`` expression displays the cluster object as the cell output.
cluster = KubeCluster(n_workers=n_workers)
cluster

In [None]:
# Connect a Dask client to the Kubernetes cluster and display it.
client = Client(cluster)
client