# Subsegmentation of ALIC tracts

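This notebook is used for rapid prototyping of the ALIC (anterior limb of the internal capsule) segmentation pipeline. For each hemisphere, the streamlines are oriented and then subdivided according to which cortical FreeSurfer label either end reaches. They are pruned with QuickBundles and saved as `.tck` tractograms plus NIfTI density maps. Once complete and relatively stable, the notebook will be converted to a Python script.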

In [None]:
# render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
import os
import sys
sys.path.append('wma_pyTools')

# import wmaPyTools.roiTools

In [None]:
import os
import sys
import itertools
import numpy as np
from pathlib import Path
import nibabel as nib
#import random

#make sure that wma_pyTools is right in the working directory, or that
#the package can otherwise be imported effectively
#sys.path.append('wma_pyTools')
startDir = Path.cwd()  # every input/output path below is resolved relative to the launch directory
import pandas as pd
import wmaPyTools.roiTools
import wmaPyTools.analysisTools
import wmaPyTools.segmentationTools
import wmaPyTools.streamlineTools
import wmaPyTools.visTools

#dipy
from dipy.tracking.utils import reduce_labels
from dipy.tracking import utils
import dipy.io.streamline
from dipy.tracking.utils import density_map

In [None]:
# Cortical target labels per hemisphere (FreeSurfer LUT numbering: 1000s are
# ctx-lh-* regions, 2000s the matching ctx-rh-* regions). Streamlines are
# assigned to a target when either end reaches it (see the 'either_end'
# criterion used in get_streams_matching_target).
targetLabels={'left':[1002,1026,1012,1020,1028,1003,1014,1019,1027],
              'right':[2002,2026,2012,2020,2028,2003,2014,2019,2027]}
# NOTE(review): spineLabels is not referenced by any cell shown here -- confirm it is still needed
spineLabels = {'left': [28, 16, 10], 
               'right': [16, 60, 49]}

#paths to input data (relative to the directory the notebook was launched from)
data_dir = startDir / 'indata'
track_files = {
    'left': [data_dir / 'combined_aLIC_left.tck',],
    'right': [data_dir / 'combined_aLIC_right.tck',]}
parcellationPath = data_dir / 'aparc+aseg.nii.gz'
refT1Path = data_dir / 'T1w_acpc_dc_restore.nii.gz'

# Freesurfer lookup table, e.g. https://surfer.nmr.mgh.harvard.edu/fswiki/FsTutorial/AnatomicalROI/FreeSurferColorLUT
lutPath = data_dir / 'FreesurferLookup.csv'

# All derived files (inflated atlas, oriented/subsegmented tractograms,
# density maps) are written here. Create the directory up front so the
# first save does not fail on a fresh checkout.
saveFigDir = startDir / 'output'
saveFigDir.mkdir(parents=True, exist_ok=True)


In [None]:
# sanity check inputs
to_check = [ parcellationPath, refT1Path, lutPath]
for side in ['left', 'right']:
    for tck in track_files[side]:
        to_check.append(tck)
        
for i in to_check:
    print(i)
    assert(i.is_file())

In [None]:
# define functions to generate tck for target

def get_streams_matching_target(streams, atlas, target):
    """Return a boolean mask over ``streams`` selecting those that reach ``target``.

    ``target`` (a label, or labels, of ``atlas``) is converted to a binary ROI
    with wmaPyTools, and the streamlines are then filtered with a single
    inclusion ROI using wmaPyTools' 'either_end' criterion (presumably: either
    endpoint lies in the ROI -- confirm in wmaPyTools.segmentationTools).
    """
    roi_mask = wmaPyTools.roiTools.multiROIrequestToMask(atlas, target)
    include_flags = [True]
    end_criteria = ['either_end']
    return wmaPyTools.segmentationTools.segmentTractMultiROI(
        streams, [roi_mask], include_flags, end_criteria)
    
def save_density_map(streams, ref_img, out_file):
    """Write a streamline density map of ``streams`` to ``out_file`` as NIfTI.

    The density is computed with dipy on ``ref_img``'s voxel grid (its affine
    and shape). The image is saved with ``ref_img``'s affine and header.

    NOTE(review): the reference header is reused as-is, so the stored data type
    may follow ``ref_img`` rather than the counts -- confirm it can hold them.
    """
    counts = utils.density_map(streams, ref_img.affine, ref_img.shape)
    nib.save(nib.nifti1.Nifti1Image(counts, ref_img.affine, ref_img.header), out_file)
    
def save_streams_matching_target(streams, atlas, lookupTable, target, out_file):
    """Select the streamlines reaching ``target``, prune them, and save them.

    Parameters
    ----------
    streams : streamlines to subsegment (indexed with a boolean mask)
    atlas : label image; used to build the target ROI and as the grid/header
        reference for the density map
    lookupTable : FreeSurfer LUT indexed by label number, with a 'LabelName:' column
    target : label number of the cortical target
    out_file : pathlib.Path without extension; ``.tck`` (tractogram) and
        ``.nii.gz`` (density map) outputs are written with those suffixes

    Returns
    -------
    Boolean mask over the input ``streams`` of those matching ``target``.
    This is the mask *before* QuickBundles pruning, so it can select more
    streamlines than were actually saved.
    """
    label_name = lookupTable.loc[target, 'LabelName:']
    print('target label is: %s (%s)' % (target, label_name))
    print(out_file)

    matches_target = get_streams_matching_target(streams, atlas, target)
    selected = streams[matches_target]
    # cull outlier streamlines via DIPY QuickBundles clustering; `bundle` is
    # defined in a later cell and is looked up when this function runs
    kept = selected[bundle(selected)]

    wmaPyTools.streamlineTools.stubbornSaveTractogram(
        kept, savePath=str(out_file.with_suffix('.tck')))
    save_density_map(kept, atlas, out_file.with_suffix('.nii.gz'))
    return matches_target

In [None]:
#apply the initial culling, to remove extraneous streamlines 
#first requires doing a DIPY quickbundling
def bundle(streams, thresholds=None, nb_pts=100, cull_threshold=3):
    """Return a boolean mask over ``streams`` keeping those that survive culling.

    The streamlines are clustered with wmaPyTools' QuickBundles wrapper and
    culled with ``cullViaClusters``. The surviving indices are then turned into
    a boolean vector with one entry per streamline.

    Parameters
    ----------
    streams : streamlines to prune
    thresholds : QuickBundles thresholds passed to ``quickbundlesClusters``
        (default ``[30, 20, 10]``, as before)
    nb_pts : passed through as ``nb_pts`` to ``quickbundlesClusters`` (default 100)
    cull_threshold : third positional argument of ``cullViaClusters``
        (default 3). Presumably a cluster-size cut-off -- TODO confirm in wmaPyTools.

    Returns
    -------
    numpy bool array of length ``len(streams)``; True marks surviving streamlines.
    """
    if thresholds is None:
        thresholds = [30, 20, 10]  # fresh list per call; avoids a mutable default

    print("DIPY quickbundle")
    clusters = wmaPyTools.streamlineTools.quickbundlesClusters(
        streams, thresholds=thresholds, nb_pts=nb_pts)

    #use those clusters to identify the streamlines to be culled
    print("identify streamlines to remove")
    survivingStreamsIndices, _culled = wmaPyTools.streamlineTools.cullViaClusters(
        clusters, streams, cull_threshold)

    #convert survivingStreamsIndices into a bool vec
    survivingStreamsBoolVec = np.zeros(len(streams), dtype=bool)
    survivingStreamsBoolVec[survivingStreamsIndices] = True

    # count from the mask: this stays correct even if the indices contain
    # duplicates or come back as a boolean vector
    n_survived = int(np.count_nonzero(survivingStreamsBoolVec))
    print('%d of %d streams survived' % (n_survived, len(survivingStreamsBoolVec)))

    return survivingStreamsBoolVec

In [None]:
# load atlas-based segmentation (Dan calls it a parcellation), i.e. the aparc+aseg label image
# NOTE: the misspelt name `parcellaton` is kept because the inflate/deIsland cell reads it
parcellaton=nib.load(parcellationPath)

In [None]:
# load the T1-weighted anatomical reference image
refT1 = nib.load(refT1Path)

In [None]:
# load the FreeSurfer colour lookup table; rows are indexed by numeric label
# (the '#No.' column) so names can be fetched with .loc[label, 'LabelName:']
lookupTable = pd.read_csv(lutPath, index_col='#No.')

In [None]:
#perform inflate & deIsland of input parcellation and save it to the output dir
# parcellationPath ends in '.nii.gz', so strip both suffixes before renaming
parc_basename = Path(parcellationPath.stem).stem
inflated_atlas_file = saveFigDir / Path(parc_basename + '_inflated').with_suffix('.nii.gz')
print(inflated_atlas_file)
inflatedAtlas, deIslandReport, inflationReport = wmaPyTools.roiTools.preProcParc(
    parcellaton,
    deIslandBool=True,
    inflateIter=2,
    retainOrigBorders=False,
    maintainIslandsLabels=None,
    erodeLabels=[2, 41],  # presumably Left/Right-Cerebral-White-Matter in the FreeSurfer LUT
)
nib.save(inflatedAtlas, filename=inflated_atlas_file)

In [None]:
#Main cell, do all the hard work

def load_oriented_streams(track_file, cache_dir):
    """Return oriented streamlines for ``track_file``, cached on disk.

    The oriented tractogram is stored as ``<cache_dir>/<stem>_oriented.tck``.
    If that file already exists it is loaded. Otherwise the raw tractogram is
    loaded, oriented with wmaPyTools, and saved there for the next run.
    """
    oriented_file = cache_dir / Path(track_file.stem + '_oriented').with_suffix('.tck')
    if oriented_file.exists():
        print('oriented tck already exists. loading %s' % oriented_file)
        return nib.streamlines.load(oriented_file).streamlines

    print('Load tck %s' % track_file)
    raw_tck = nib.streamlines.load(track_file)
    print("orienting streamlines")
    oriented = wmaPyTools.streamlineTools.orientAllStreamlines(raw_tck.streamlines)
    # QuickBundles is deliberately not applied here: it takes too long on the whole tract
    print('saving oriented tck %s' % oriented_file)
    wmaPyTools.streamlineTools.stubbornSaveTractogram(oriented, savePath=str(oriented_file))
    return oriented


for iSide in ('left', 'right'):
    for track_file in track_files[iSide]:
        streams = load_oriented_streams(track_file, saveFigDir)

        # density map of the whole (not yet subsegmented) tract
        parent_density_file = saveFigDir / Path(track_file.stem).with_suffix('.nii.gz')
        print('saving density map %s' % parent_density_file)
        save_density_map(streams, inflatedAtlas, parent_density_file)

        for iTarget in targetLabels[iSide]:
            targetStr = lookupTable.loc[iTarget, 'LabelName:']
            out_file = saveFigDir / ('%s_%04d_%s' % (track_file.stem, iTarget, targetStr))
            print('Starting processing for %s' % out_file.stem)
            # subsegment the streams and save the resulting density map and tck tractogram
            targetBool = save_streams_matching_target(
                streams, inflatedAtlas, lookupTable, iTarget, out_file)

            

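## Appendix

The first cell below sums the per-tractogram density maps for each target. The remaining cells are earlier single-tractogram exploration: the cells marked SKIPPED, plus those that build on their `lite_streams`. They expect `tckPath` to be defined by hand and otherwise fail with a `NameError` on a fresh kernel.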

In [None]:
# generate combined niftis: for each side and target, sum the per-tractogram
# density maps written by the main cell into a single image
for iSide in ['left', 'right']:
    for iTarget in targetLabels[iSide]:
        # the label name depends only on the target: look it up once, so it is
        # defined for the output name even when the inner loop does not run
        targetStr = lookupTable.loc[iTarget, 'LabelName:']
        combined_img = None
        template_img = None
        for track_file in track_files[iSide]:  # iterate over inferior and superior
            in_file = saveFigDir / ('%s_%04d_%s' % (track_file.stem, iTarget, targetStr))
            print(in_file)
            density_img = nib.load(in_file.with_suffix('.nii.gz'))
            if template_img is None:
                # The density maps were written on the (inflated) atlas grid,
                # so the sum must carry their affine/header, not the T1's.
                template_img = density_img
                combined_img = np.zeros(density_img.shape)
            elif not np.allclose(density_img.affine, template_img.affine):
                raise ValueError('density maps are on different voxel grids: %s' % in_file)
            combined_img += density_img.get_fdata()
        if template_img is None:
            print('no tractograms for %s side; skipping target %d' % (iSide, iTarget))
            continue
        combined_file = saveFigDir / ('combined_aLIC_%04d_%s' % (iTarget, targetStr))
        combinedNifti = nib.nifti1.Nifti1Image(combined_img, template_img.affine, template_img.header)
        print(combined_file)
        nib.save(combinedNifti, combined_file.with_suffix('.nii.gz'))
        

In [None]:
# load tractogram
# SKIPPED because we're iterating over multiple tractograms
# NOTE(review): `tckPath` is not assigned by any uncommented cell above, so on
# a fresh kernel this line raises NameError -- define tckPath before running it.
tckIn=nib.streamlines.load(tckPath)

In [None]:
# orient all the streamlines, potentially not necessary given redundancy with 
# subsequent steps
# SKIPPED because we're using either_end selection
# (depends on `tckIn` from the previous, also skipped, cell)

print("orienting streamlines")
orientedStreams=wmaPyTools.streamlineTools.orientAllStreamlines(tckIn.streamlines)

In [None]:
# save oriented streams to <saveFigDir>/track_oriented.tck
# SKIPPED
print("save oriented tck")
subTckSavePath = str(saveFigDir / 'track_oriented.tck')
wmaPyTools.streamlineTools.stubbornSaveTractogram(
    orientedStreams,
    savePath=subTckSavePath,
)


In [None]:
# save lite streamlines: a reproducible random subset of the oriented streams
#SKIPPED
n_streams_to_keep = int(5E5)
lite_seed = 0  # fixed seed so the same "lite" subset is drawn on every run
n_total = len(orientedStreams)
# choice(..., replace=False) raises if asked for more items than exist, so cap the sample size
n_lite = min(n_streams_to_keep, n_total)
select_idx = np.random.default_rng(lite_seed).choice(n_total, n_lite, replace=False)  # integer indices
lite_streams = orientedStreams[select_idx]
wmaPyTools.streamlineTools.stubbornSaveTractogram(lite_streams,
    savePath=str(Path(saveFigDir) / 'track_lite.tck'))


In [None]:
# select streamlines to work on
# NOTE: this rebinds `streams`, replacing the last tractogram loaded by the main cell

streams = lite_streams


In [None]:
from dipy.tracking.utils import density_map
from wmaPyTools.visTools import multiTileDensity

# density visualisation of `streams` against refT1; output presumably goes
# to saveFigDir under the name 'density' -- TODO confirm in wmaPyTools.visTools
multiTileDensity(
    streams,
    refT1,
    saveFigDir,
    'density',
    densityThreshold=0,
    noEmpties=True,
)

In [None]:
# Label-to-label connectivity of `streams` over the inflated atlas. Per dipy's
# docs, M[i, j] counts streamlines whose endpoints fall in labels i and j, and
# with mapping_as_streamlines=True, `grouping` maps each (i, j) pair to those
# streamlines.
# The label volume is cast with the builtin int: `np.int` was only an alias
# for it and was removed in NumPy 1.24 (AttributeError there).
M, grouping = utils.connectivity_matrix(streams, inflatedAtlas.affine,
                                        inflatedAtlas.get_fdata().astype(int),
                                        return_mapping=True,
                                        mapping_as_streamlines=True)

In [None]:
M.shape  # dimensions of the label-by-label connectivity matrix

In [None]:
# list every target label (both hemispheres, ascending) followed by its FreeSurfer name
targets = sorted(targetLabels['left'] + targetLabels['right'])
for iTarget in targets:
    print(iTarget, lookupTable.loc[iTarget, 'LabelName:'], sep='\n')