In [None]:
import os
from glob import glob

import dipy
import numpy as np
from dipy.io.streamline import load_tractogram
from dipy.segment.clustering import QuickBundles

# Collect every summed_*.trk atlas bundle. glob returns files in arbitrary
# filesystem order, so sort to keep the processing order reproducible.
bundles = sorted(glob('/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/summed_*.trk'))

# Pick the right/left CST files by name. The generator variable has its own
# name so it no longer shadows the `bundles` list it iterates over.
cst_right = next((path for path in bundles if 'CST_right' in path), None)
cst_left = next((path for path in bundles if 'CST_left' in path), None)

# Fail with an explicit message instead of an opaque IndexError.
if cst_right is None or cst_left is None:
    raise FileNotFoundError(
        "Could not find both a 'CST_right' and a 'CST_left' file among the "
        f"{len(bundles)} summed_*.trk bundles"
    )

print(cst_right)
print(cst_left)


In [None]:
def find_all_thresholds_for_3_clusters(streamlines, max_threshold=30, step=0.5,
                                       min_threshold=None):
    """Scan QuickBundles thresholds and keep those that yield exactly 3 clusters.

    Parameters
    ----------
    streamlines : sequence
        Streamlines passed unchanged to ``QuickBundles.cluster``.
    max_threshold : float
        Upper end of the scan (the range stops before ``max_threshold + step``).
    step : float
        Spacing between two consecutive thresholds.
    min_threshold : float, optional
        First threshold tried. Defaults to ``step + 18``, the start value that
        was previously hard-coded, so existing calls behave as before.

    Returns
    -------
    list of float
        Thresholds, in increasing order, for which exactly 3 clusters were found.
    """
    if min_threshold is None:
        min_threshold = step + 18

    working_thresholds = []
    for threshold in np.arange(min_threshold, max_threshold + step, step):
        # Plain Python floats keep the printed lists and JSON output clean
        # (numpy scalars show up as np.float64(...) in list reprs).
        threshold = float(threshold)
        n_clusters = len(QuickBundles(threshold=threshold).cluster(streamlines))
        print(f"Threshold: {threshold}, Number of clusters: {n_clusters}")
        if n_clusters == 3:
            working_thresholds.append(threshold)

    return working_thresholds

# Load the streamlines of both CST bundles. No cell above defines
# `cst_right_streamlines` / `cst_left_streamlines`, so without this step the
# cell raises a NameError on a fresh kernel.
cst_right_streamlines = load_tractogram(cst_right, reference='same', bbox_valid_check=False).streamlines
cst_left_streamlines = load_tractogram(cst_left, reference='same', bbox_valid_check=False).streamlines

# Find all working thresholds for both CST tracts
thresholds_right_all = find_all_thresholds_for_3_clusters(cst_right_streamlines)
thresholds_left_all = find_all_thresholds_for_3_clusters(cst_left_streamlines)

print(f"CST Right - All thresholds that work: {thresholds_right_all}")
print(f"CST Left - All thresholds that work: {thresholds_left_all}")

# Thresholds that give 3 clusters on both sides, in increasing order
common_thresholds = sorted(set(thresholds_right_all) & set(thresholds_left_all))

print(f"Common thresholds that work for both: {common_thresholds}")

if common_thresholds:
    # The list is sorted, so the last entry is the largest common threshold
    largest_common_threshold = common_thresholds[-1]
    print(f"Largest threshold that works for both CST tracts: {largest_common_threshold}")

    # Verify with the largest common threshold
    qb_common = QuickBundles(threshold=largest_common_threshold)
    clusters_right_common = qb_common.cluster(cst_right_streamlines)
    clusters_left_common = qb_common.cluster(cst_left_streamlines)

    print(f"Verification with threshold {largest_common_threshold}:")
    print(f"CST Right - Number of clusters: {len(clusters_right_common)}")
    print(f"CST Left - Number of clusters: {len(clusters_left_common)}")
else:
    print("No common thresholds found that work for both CST tracts")


## One central line

In [None]:
import numpy as np
from dipy.segment.clustering import QuickBundles
from dipy.io.streamline import load_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram,Space

# Imported once here; it used to be re-imported on every loop iteration.
from dipy.io.streamline import save_tractogram

output_dir = "/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/central_line"
# QuickBundles distance threshold. NOTE(review): presumably chosen very large so
# that each bundle collapses into a single cluster / central line; the printed
# cluster count below lets you confirm that.
threshold = 100.

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Process all bundles
for bundle_path in bundles:
    print(f"Processing {bundle_path} with threshold {threshold}")

    # Load tractogram
    tractogram = load_tractogram(bundle_path, reference='same', bbox_valid_check=False)
    streamlines = tractogram.streamlines

    # Apply QuickBundles clustering
    qb = QuickBundles(threshold=threshold)
    clusters = qb.cluster(streamlines)

    # Extract centroids
    centroids = [cluster.centroid for cluster in clusters]

    # Create a new tractogram with centroids, reusing the input as spatial reference
    centroids_tractogram = StatefulTractogram(centroids, tractogram, Space.RASMM)

    # File name without directory and extension (only the final extension is
    # stripped; the previous replace() removed every '.trk' occurrence)
    bundle_name = os.path.splitext(os.path.basename(bundle_path))[0]
    print(f"Bundle {bundle_name} - Number of clusters: {len(clusters)}")

    # Save centroids as .trk file
    output_path = os.path.join(output_dir, f"{bundle_name}_centroids.trk")
    save_tractogram(centroids_tractogram, output_path)

    print(f"Saved {len(centroids)} centroids to {output_path}")


In [None]:
import subprocess

ref = '/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/average_anat.nii.gz'
# Get all .trk files in the output directory
trk_files = glob(f"{output_dir}/*.trk")
out_flip_vtk = "/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/central_line"

# Nothing above creates this folder, so make sure it exists before the
# external tool is asked to write into it.
os.makedirs(out_flip_vtk, exist_ok=True)

print(f"Found {len(trk_files)} .trk files in {output_dir}")

# Apply flip_tractogram command to each file
for trk_file in trk_files:
    out_vtk = os.path.join(out_flip_vtk, os.path.basename(trk_file).replace('.trk', '.vtk'))
    print(f'Processing {trk_file}...')

    # Run the flip_tractogram command
    try:
        result = subprocess.run(['flip_tractogram', trk_file, out_vtk, '--reference', ref],
                                capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as err:
        # capture_output=True swallows the tool's messages; show them before
        # propagating the failure so the cause is visible in the notebook.
        print(f"flip_tractogram failed for {trk_file}:\n{err.stderr}")
        raise
    print(f"Successfully flipped {trk_file}")


print("Finished processing all files")

## Long streamlines central line

In [None]:
import os 
import dipy 
from glob import glob

# Collect every summed_*.trk atlas bundle. glob returns files in arbitrary
# filesystem order, so sort to keep the processing order reproducible.
bundles = sorted(glob('/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/summed_*.trk'))

# Pick the right/left CST files by name. The generator variable has its own
# name so it no longer shadows the `bundles` list it iterates over.
cst_right = next((path for path in bundles if 'CST_right' in path), None)
cst_left = next((path for path in bundles if 'CST_left' in path), None)

# Fail with an explicit message instead of an opaque IndexError.
if cst_right is None or cst_left is None:
    raise FileNotFoundError(
        "Could not find both a 'CST_right' and a 'CST_left' file among the "
        f"{len(bundles)} summed_*.trk bundles"
    )

print(cst_right)
print(cst_left)

In [None]:
import os
import numpy as np
from dipy.segment.clustering import QuickBundles
from dipy.io.streamline import load_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram,Space
import shapely
from shapely import LineString, MultiLineString, Polygon
# Destination folder for the centroids computed from the longest streamlines
output_dir = (
    "/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/long_central_line"
)
# QuickBundles distance threshold used by the loop below
threshold = 100.0

def get_longest_streamlines(streamlines, number=15):
    """
    Return the N longest streamlines, ordered from shorter to longer.

    Lengths are true 3D arc lengths (sum of the Euclidean norms of consecutive
    segments). The previous shapely-based length only measured the projection
    onto the x-y plane, because shapely ignores the z coordinate in geometric
    computations.

    Parameters
    ----------
    streamlines : list
        List of streamlines (each streamline is a numpy array of shape (N, 3))
    number : int
        Number of streamlines to return (15 by default). Values <= 0 give an
        empty list.
    Returns
    -------
    longest_streamlines : list
        The N longest streamlines (all of them if fewer than N are available)
    """
    def _arc_length(points):
        # Streamlines with fewer than two points have no segment, hence length 0
        pts = np.asarray(points, dtype=float)
        if pts.ndim != 2 or len(pts) < 2:
            return 0.0
        return float(np.linalg.norm(np.diff(pts, axis=0), axis=1).sum())

    number = int(number)
    if number <= 0 or len(streamlines) == 0:
        # Guard against the `[-0:]` slice, which would select every streamline
        longest_streamlines = []
    else:
        lengths = [_arc_length(s) for s in streamlines]
        # Indices of the N largest lengths (stable sort keeps ties deterministic)
        longest_indices = np.argsort(lengths, kind="stable")[-number:]
        longest_streamlines = [streamlines[i] for i in longest_indices]

    print(f"Streamlines les plus longues : {len(longest_streamlines)}")
    return longest_streamlines

# Imported once here; it used to be re-imported on every loop iteration.
from dipy.io.streamline import save_tractogram

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Process all bundles
for bundle_path in bundles:
    print(f"Processing {bundle_path} with threshold {threshold}")

    # Load tractogram
    tractogram = load_tractogram(bundle_path, reference='same', bbox_valid_check=False)
    streamlines = tractogram.streamlines

    # Keep the longest 10% of the streamlines, but always at least one: for
    # bundles with fewer than 10 streamlines int(0.1 * n) is 0, which would not
    # select "the longest" subset at all.
    n_longest = max(1, int(0.1 * len(streamlines)))
    streamlines = get_longest_streamlines(streamlines, number=n_longest)

    # Apply QuickBundles clustering
    qb = QuickBundles(threshold=threshold)
    clusters = qb.cluster(streamlines)

    # Extract centroids
    centroids = [cluster.centroid for cluster in clusters]

    # Create a new tractogram with centroids, reusing the input as spatial reference
    centroids_tractogram = StatefulTractogram(centroids, tractogram, Space.RASMM)

    # File name without directory and extension (only the final extension is stripped)
    bundle_name = os.path.splitext(os.path.basename(bundle_path))[0]
    print(f"Bundle {bundle_name} - Number of clusters: {len(clusters)}")

    # Save centroids as .trk file
    output_path = os.path.join(output_dir, f"{bundle_name}_centroids.trk")
    save_tractogram(centroids_tractogram, output_path)

    print(f"Saved {len(centroids)} centroids to {output_path}")


In [None]:
import subprocess

ref = '/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/average_anat.nii.gz'
# Get all .trk files in the output directory
trk_files = glob(f"{output_dir}/*.trk")
out_flip_vtk = "/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/long_central_line"

# Make sure the destination folder exists before the external tool writes into it
os.makedirs(out_flip_vtk, exist_ok=True)

print(f"Found {len(trk_files)} .trk files in {output_dir}")

# Apply flip_tractogram command to each file
for trk_file in trk_files:
    out_vtk = os.path.join(out_flip_vtk, os.path.basename(trk_file).replace('.trk', '.vtk'))
    print(f'Processing {trk_file}...')

    # Run the flip_tractogram command
    try:
        result = subprocess.run(['flip_tractogram', trk_file, out_vtk, '--reference', ref],
                                capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as err:
        # capture_output=True swallows the tool's messages; show them before
        # propagating the failure so the cause is visible in the notebook.
        print(f"flip_tractogram failed for {trk_file}:\n{err.stderr}")
        raise
    print(f"Successfully flipped {trk_file}")


print("Finished processing all files")

## Multiple central lines

In [None]:
import numpy as np
from dipy.segment.clustering import QuickBundles
from dipy.io.streamline import load_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram,Space

import json

# Imported explicitly: this cell previously only worked because an earlier
# cell had already imported save_tractogram inside its loop.
from dipy.io.streamline import save_tractogram

output_dir = "/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/centroids"
# QuickBundles distance threshold; smaller than in the single-central-line
# cells, so bundles may split into several clusters.
threshold = 30.

# The folder was never created before files were written into it
os.makedirs(output_dir, exist_ok=True)

# Dictionary to store bundle names and their cluster counts
bundle_cluster_info = {}

# Process all bundles
for bundle_path in bundles:
    print(f"Processing {bundle_path} with threshold {threshold}")

    # Load tractogram
    tractogram = load_tractogram(bundle_path, reference='same', bbox_valid_check=False)
    streamlines = tractogram.streamlines

    # Apply QuickBundles clustering
    qb = QuickBundles(threshold=threshold)
    clusters = qb.cluster(streamlines)

    # Extract centroids
    centroids = [cluster.centroid for cluster in clusters]

    # Create a new tractogram with centroids, reusing the input as spatial reference
    centroids_tractogram = StatefulTractogram(centroids, tractogram, Space.RASMM)

    # File name without directory and extension (only the final extension is stripped)
    bundle_name = os.path.splitext(os.path.basename(bundle_path))[0]
    print(f"Bundle {bundle_name} - Number of clusters: {len(clusters)}")

    # Store bundle info
    bundle_cluster_info[bundle_name] = len(clusters)

    # Save centroids as .vtk file
    output_path = os.path.join(output_dir, f"{bundle_name}_centroids.vtk")
    save_tractogram(centroids_tractogram, output_path)

    print(f"Saved {len(centroids)} centroids to {output_path}")

# Save the bundle cluster information to JSON
json_output_path = os.path.join(output_dir, "bundle_cluster_info.json")
with open(json_output_path, 'w') as json_file:
    json.dump(bundle_cluster_info, json_file, indent=2)

print(f"Saved bundle cluster information to {json_output_path}")

### Multiple clusters with a Fréchet-like (Hausdorff) distance

In [None]:
import os
import json
from glob import glob
import numpy as np
import vtk
from vtk.util.numpy_support import numpy_to_vtk
from dipy.segment.clustering import QuickBundles
from dipy.segment.featurespeed import Feature
from dipy.segment.metric import Metric
from dipy.tracking.streamline import Streamlines, set_number_of_points
from shapely.geometry import LineString
import shapely

def load_vtk_streamlines(vtk_file_path):
    """Read the line cells of a VTK PolyData file as a list of point arrays.

    Parameters
    ----------
    vtk_file_path : str
        Path handed to ``vtkPolyDataReader``.

    Returns
    -------
    list of numpy.ndarray
        One array per line cell, holding the coordinates of its points in
        cell order.
    """
    reader = vtk.vtkPolyDataReader()
    reader.SetFileName(vtk_file_path)
    reader.Update()
    polydata = reader.GetOutput()

    cell_array = polydata.GetLines()
    cell_array.InitTraversal()
    point_ids = vtk.vtkIdList()

    streamlines = []
    # GetNextCell fills `point_ids` in place and returns a falsy value once
    # every cell has been visited.
    while cell_array.GetNextCell(point_ids):
        coords = [
            polydata.GetPoint(point_ids.GetId(k))
            for k in range(point_ids.GetNumberOfIds())
        ]
        streamlines.append(np.array(coords))
    return streamlines

# --- Reorientation by Manhattan distance to the mean streamline ---
def compute_mean_streamline(streamlines):
    """Return the centroid of the first QuickBundles cluster, or None.

    The clustering uses a very large threshold (1000); NOTE(review): presumably
    so that all streamlines end up in a single cluster whose centroid acts as
    the mean streamline.

    Parameters
    ----------
    streamlines : sequence
        Streamlines passed to ``QuickBundles.cluster``.

    Returns
    -------
    centroid or None
        First entry of ``clusters.centroids``; ``None`` when there is nothing
        to cluster. ``reorient_to_reference_manhattan`` already treats a
        ``None`` reference as "leave the streamlines untouched".
    """
    if len(streamlines) == 0:
        return None
    clusters = QuickBundles(threshold=1000).cluster(streamlines)
    centroids = clusters.centroids
    if len(centroids) == 0:
        return None
    return centroids[0]

def reorient_to_reference_manhattan(streamlines, reference):
    """Flip each streamline so that its point order best matches ``reference``.

    A streamline is reversed when the sum of absolute coordinate differences
    to ``reference`` is strictly smaller for the reversed point order than for
    the original one. Empty streamlines are passed through unchanged.

    Parameters
    ----------
    streamlines : iterable of arrays
        Each non-empty streamline must be subtractable from ``reference``
        (assumes the same number of points — TODO confirm at the call site).
    reference : array or None
        Reference streamline; with ``None`` the input is returned as is.

    Returns
    -------
    The input object itself when ``reference`` is None, otherwise a new
    ``Streamlines`` built from the re-oriented lines.
    """
    if reference is None:
        return streamlines

    def _align(line):
        if len(line) == 0:
            return line
        flipped = line[::-1]
        cost_as_is = np.abs(line - reference).sum()
        cost_flipped = np.abs(flipped - reference).sum()
        # Ties keep the original orientation
        return flipped if cost_flipped < cost_as_is else line

    return Streamlines([_align(line) for line in streamlines])

class FrechetDistanceFeature(Feature):
    """Identity feature: a streamline is used as-is (as a numpy array)."""

    def infer_shape(self, datum):
        """Return the shape of ``datum`` once viewed as an array."""
        as_array = np.asarray(datum)
        return as_array.shape

    def extract(self, datum):
        """Return ``datum`` converted to a numpy array, without further processing."""
        as_array = np.asarray(datum)
        return as_array

class FrechetDistanceMetric(Metric):
    """Symmetric Hausdorff distance between two polylines, computed in 3D.

    Despite the class name, the distance is a Hausdorff distance (as in the
    original shapely-based code), not a Fréchet distance. It is computed with
    numpy because shapely performs its geometric operations in the x-y plane
    only: the z coordinate of the streamlines was silently ignored, so the
    clustering compared 2D projections.
    """

    def are_compatible(self, shape1, shape2):
        """Any pair of feature shapes can be compared."""
        return True

    @staticmethod
    def _directed_distance(points, polyline):
        """Largest distance from a vertex in ``points`` to the line ``polyline``.

        For every vertex, the distance to the closest point on any segment of
        ``polyline`` is computed; the maximum over the vertices is returned.
        """
        if len(polyline) == 1:
            # A single point has no segment: fall back to point distances
            return np.linalg.norm(points - polyline[0], axis=1).max()

        starts = polyline[:-1]
        segments = polyline[1:] - starts                        # (S, 3)
        seg_len2 = np.einsum('ij,ij->i', segments, segments)    # (S,)
        # Zero-length segments: any divisor works because the projection is 0
        safe_len2 = np.where(seg_len2 > 0, seg_len2, 1.0)

        rel = points[:, None, :] - starts[None, :, :]           # (P, S, 3)
        # Position of the orthogonal projection along each segment, clamped
        # so that the closest point stays on the segment
        t = np.einsum('psj,sj->ps', rel, segments) / safe_len2
        t = np.clip(t, 0.0, 1.0)
        closest = starts[None, :, :] + t[:, :, None] * segments[None, :, :]
        dists = np.linalg.norm(points[:, None, :] - closest, axis=2)  # (P, S)
        return dists.min(axis=1).max()

    def distance(self, feature1, feature2):
        """Return the symmetric 3D Hausdorff distance between two point arrays."""
        line_a = np.asarray(feature1, dtype=float)
        line_b = np.asarray(feature2, dtype=float)
        return float(max(self._directed_distance(line_a, line_b),
                         self._directed_distance(line_b, line_a)))

    def dist(self, features1, features2):
        """Entry point used by the clustering; delegates to ``distance``."""
        return self.distance(features1, features2)

output_dir = "/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/centroids_frechet"
os.makedirs(output_dir, exist_ok=True)
threshold = 30.0
# Size limit passed to get_small_clusters: clusters reported as "small" for
# this limit are discarded after clustering.
min_cluster_size = 1000

feature = FrechetDistanceFeature()
metric = FrechetDistanceMetric(feature=feature)

bundle_cluster_info = {}
# Sorted so that the processing order is reproducible: the first hemisphere
# processed decides the cluster cap of its counterpart (see below), and glob
# alone returns files in arbitrary order.
bundles = sorted(glob('/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/vtk/summed_*.vtk'))

for bundle_path in bundles:
    print('\n')
    bundle_name = os.path.basename(bundle_path).replace('.vtk', '')
    print(f"---------------- Processing {bundle_name} --------------")
    sl = load_vtk_streamlines(bundle_path)
    # Resample every streamline to 12 points
    sl = set_number_of_points(Streamlines(sl), 12)
    # Compute the mean streamline and re-orient all streamlines towards it
    mean_ref = compute_mean_streamline(sl)
    sl = reorient_to_reference_manhattan(sl, mean_ref)

    # Length statistics (optional). Computed with numpy in 3D: shapely's
    # LineString.length ignores the z coordinate and under-reports lengths.
    lengths = [
        float(np.linalg.norm(np.diff(np.asarray(s), axis=0), axis=1).sum())
        for s in sl
        if len(s) > 0
    ]
    if len(lengths) > 0:
        mean_length = np.mean(lengths)
        print(f"Mean streamline length for {bundle_name}: {mean_length:.2f} mm")

    print(f"Processing {bundle_name} with threshold {threshold} (Frechet)")

    # Default cap on the number of clusters; if the contralateral bundle was
    # already processed, reuse its final cluster count as the cap instead.
    max_clusters = 3
    for sibling_name in (bundle_name.replace('left', 'right'),
                         bundle_name.replace('right', 'left')):
        if sibling_name in bundle_cluster_info:
            # A sibling that ended up with 0 clusters (all removed as too
            # small) must not forbid any cluster for this bundle.
            max_clusters = max(1, bundle_cluster_info[sibling_name])
            break

    print(f"Max clusters for {bundle_name}: {max_clusters}")

    qb = QuickBundles(threshold=threshold, metric=metric, max_nb_clusters=max_clusters)
    clusters = qb.cluster(sl)

    # Drop the clusters that hold too few streamlines
    bad_clusters = clusters.get_small_clusters(min_cluster_size)
    for bad_cluster in bad_clusters:
        print(f"Removing bad cluster with {len(bad_cluster.indices)} streamlines for {bundle_name}")
        clusters.remove_cluster(bad_cluster)

    centroids = clusters.centroids

    print(f"Bundle {bundle_name} - Number of clusters (Frechet): {len(clusters)}")
    bundle_cluster_info[bundle_name] = len(clusters)

    # Write the centroids to VTK, one polyline per centroid
    centroid_polydata = vtk.vtkPolyData()
    centroid_points = vtk.vtkPoints()
    centroid_lines = vtk.vtkCellArray()
    centroid_polydata.SetPoints(centroid_points)
    centroid_indices = []
    for cid, c in enumerate(centroids):
        line = vtk.vtkPolyLine()
        line.GetPointIds().SetNumberOfIds(len(c))
        for i, p in enumerate(c):
            pid = centroid_points.InsertNextPoint(float(p[0]), float(p[1]), float(p[2]))
            line.GetPointIds().SetId(i, pid)
            # Per-point array: index of the centroid the point belongs to
            centroid_indices.append(cid)
        centroid_lines.InsertNextCell(line)
    centroid_polydata.SetLines(centroid_lines)
    centroid_index_array = numpy_to_vtk(np.array(centroid_indices), deep=True)
    centroid_index_array.SetName('centroid_index')
    centroid_polydata.GetPointData().AddArray(centroid_index_array)
    centroid_writer = vtk.vtkPolyDataWriter()
    centroid_writer.SetFileName(os.path.join(output_dir, f"{bundle_name}_centroids.vtk"))
    centroid_writer.SetInputData(centroid_polydata)
    centroid_writer.Write()

    # Write the full model with per-point centroid_index and point_index arrays
    model_polydata = vtk.vtkPolyData()
    model_points = vtk.vtkPoints()
    model_lines = vtk.vtkCellArray()
    model_polydata.SetPoints(model_points)

    # Mapping streamline -> cluster id; -1 marks streamlines whose cluster
    # was removed above
    streamline_cluster_ids = np.full(len(sl), -1, dtype=int)
    for cid, c in enumerate(clusters):
        for sidx in c.indices:
            streamline_cluster_ids[sidx] = cid

    model_centroid_indices = []
    model_point_indices = []
    for sidx, s in enumerate(sl):
        line = vtk.vtkPolyLine()
        line.GetPointIds().SetNumberOfIds(len(s))
        for i, p in enumerate(s):
            pid = model_points.InsertNextPoint(float(p[0]), float(p[1]), float(p[2]))
            line.GetPointIds().SetId(i, pid)
            model_centroid_indices.append(streamline_cluster_ids[sidx])
            # Position of the point along its (resampled) streamline
            model_point_indices.append(i)
        model_lines.InsertNextCell(line)
    model_polydata.SetLines(model_lines)
    arr_centroid_index = numpy_to_vtk(np.array(model_centroid_indices), deep=True)
    arr_centroid_index.SetName('centroid_index')
    model_polydata.GetPointData().AddArray(arr_centroid_index)
    arr_point_index = numpy_to_vtk(np.array(model_point_indices), deep=True)
    arr_point_index.SetName('point_index')
    model_polydata.GetPointData().AddArray(arr_point_index)
    model_writer = vtk.vtkPolyDataWriter()
    model_writer.SetFileName(os.path.join(output_dir, f"{bundle_name}_model_with_centroid_index.vtk"))
    model_writer.SetInputData(model_polydata)
    model_writer.Write()

# Persist the final cluster count of every bundle
json_output_path = os.path.join(output_dir, 'bundle_cluster_info.json')
with open(json_output_path, 'w') as f:
    json.dump(bundle_cluster_info, f, indent=2)
print(f"Saved bundle cluster information to {json_output_path}")

Saved bundle cluster information to /home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/centroids_frechet/bundle_cluster_info.json


### Test du pipeline d'association HCP

In [None]:
# Smoke test of the HCP association pipeline on a single subject.
# The project package is made importable by appending its checkout to sys.path.
import sys
sys.path.append('/home/ndecaux/Code/actiDep')

from actiDep.pipeline.hcp_association import process_hcp_association, get_bundle_mapping
from actiDep.data.loader import Subject, Actidep

# Check the bundle mapping: print every key/value pair it contains
# (presumably bundle name -> centroid, going by the loop variable names; verify
# against get_bundle_mapping).
bundle_mapping = get_bundle_mapping()
print("Bundle mapping disponible :")
for bundle, centroid in bundle_mapping.items():
    print(f"  {bundle} -> {centroid}")

# Try with one subject: the first id listed by the dataset, if any
ds = Actidep('/home/ndecaux/NAS_EMPENN/share/projects/actidep/bids')
if len(ds.subject_ids) > 0:
    subject = ds.get_subject(ds.subject_ids[0])
    
    # Look up the VTK tractography files available for this subject
    vtk_files = subject.get(
        pipeline='mcm_to_hcp_space',
        space='HCP', 
        extension='vtk',
        datatype='tracto'
    )
    
    print(f"\nFichiers VTK trouvés pour {subject.sub_id}: {len(vtk_files)}")
    for vtk_file in vtk_files:
        entities = vtk_file.get_entities()
        # 'unknown' is printed when the file has no 'bundle' entity
        print(f"  Bundle: {entities.get('bundle', 'unknown')}")
    
    # Run the pipeline on this subject, only when at least one VTK file was found
    if len(vtk_files) > 0:
        print(f"\nTraitement du sujet {subject.sub_id}...")
        process_hcp_association(subject)