In [None]:
import os 
import dipy 
from glob import glob

#Get all summed_*.trk from /home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/
bundles = glob('/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/summed_*.trk')

cst_right = [bundles for bundles in bundles if 'CST_right' in bundles][0]
cst_left = [bundles for bundles in bundles if 'CST_left' in bundles][0]

print(cst_right)
print(cst_left)


In [None]:
# Function to find all thresholds that generate exactly 3 clusters
def find_all_thresholds_for_3_clusters(streamlines, max_threshold=30, step=0.5):
    working_thresholds = []
    
    for threshold in np.arange(step+18, max_threshold + step, step):
        qb = QuickBundles(threshold=threshold)
        clusters = qb.cluster(streamlines)
        print(f"Threshold: {threshold}, Number of clusters: {len(clusters)}")
        if len(clusters) == 3:
            working_thresholds.append(threshold)
    
    return working_thresholds

# Find all working thresholds for both CST tracts
thresholds_right_all = find_all_thresholds_for_3_clusters(cst_right_streamlines)
thresholds_left_all = find_all_thresholds_for_3_clusters(cst_left_streamlines)

print(f"CST Right - All thresholds that work: {thresholds_right_all}")
print(f"CST Left - All thresholds that work: {thresholds_left_all}")

# Find common thresholds that work for both
common_thresholds = list(set(thresholds_right_all) & set(thresholds_left_all))
common_thresholds.sort()

print(f"Common thresholds that work for both: {common_thresholds}")

if common_thresholds:
    largest_common_threshold = max(common_thresholds)
    print(f"Largest threshold that works for both CST tracts: {largest_common_threshold}")
    
    # Verify with the largest common threshold
    qb_common = QuickBundles(threshold=largest_common_threshold)
    clusters_right_common = qb_common.cluster(cst_right_streamlines)
    clusters_left_common = qb_common.cluster(cst_left_streamlines)
    
    print(f"Verification with threshold {largest_common_threshold}:")
    print(f"CST Right - Number of clusters: {len(clusters_right_common)}")
    print(f"CST Left - Number of clusters: {len(clusters_left_common)}")
else:
    print("No common thresholds found that work for both CST tracts")


### One central line

In [None]:
import numpy as np
from dipy.segment.clustering import QuickBundles
from dipy.io.streamline import load_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram,Space

output_dir="/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/central_line"
threshold=100.

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Process all bundles
for bundle_path in bundles:
    print(f"Processing {bundle_path} with threshold {threshold}")
    
    # Load tractogram
    tractogram = load_tractogram(bundle_path, reference='same',bbox_valid_check=False)
    streamlines = tractogram.streamlines
    
    # Apply QuickBundles clustering
    qb = QuickBundles(threshold=threshold)
    clusters = qb.cluster(streamlines)
    
    # Extract centroids
    centroids = [cluster.centroid for cluster in clusters]

    
    # Create a new tractogram with centroids
    centroids_tractogram = StatefulTractogram(centroids, tractogram, Space.RASMM)
    
    # Save centroids as .trk file
    bundle_name = bundle_path.split('/')[-1].replace('.trk', '')
    print(f"Bundle {bundle_name} - Number of clusters: {len(clusters)}")

    output_path = f"{output_dir}/{bundle_name}_centroids.trk"
    
    from dipy.io.streamline import save_tractogram
    save_tractogram(centroids_tractogram, output_path)
    
    print(f"Saved {len(centroids)} centroids to {output_path}")


In [None]:
import subprocess

ref = '/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/average_anat.nii.gz'
# Get all .trk files in the output directory
trk_files = glob(f"{output_dir}/*.trk")
out_flip_vtk = f"/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/central_line"

print(f"Found {len(trk_files)} .trk files in {output_dir}")

# Apply flip_tractogram command to each file
for trk_file in trk_files:
    out_vtk= f"{out_flip_vtk}/{os.path.basename(trk_file).replace('.trk', '.vtk')}"
    print(f'Processing {trk_file}...')

    # Run the flip_tractogram command
    result = subprocess.run(['flip_tractogram', trk_file, out_vtk,'--reference',ref], 
                            capture_output=True, text=True, check=True)
    print(f"Successfully flipped {trk_file}")


print("Finished processing all files")

## Long streamlines central line

In [None]:
import os 
import dipy 
from glob import glob

#Get all summed_*.trk from /home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/
bundles = glob('/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/summed_*.trk')

cst_right = [bundles for bundles in bundles if 'CST_right' in bundles][0]
cst_left = [bundles for bundles in bundles if 'CST_left' in bundles][0]

print(cst_right)
print(cst_left)

In [None]:
import os
import numpy as np
from dipy.segment.clustering import QuickBundles
from dipy.io.streamline import load_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram,Space
import shapely
from shapely import LineString, MultiLineString, Polygon
output_dir="/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/long_central_line"
threshold=100.

def get_longest_streamlines(streamlines, number=15):
    """
    Trouve les N streamlines les plus longues.
    Parameters
    ----------
    streamlines : list
        Liste des streamlines (chaque streamline est un array numpy (N, 3))
    number : int
        Nombre de streamlines à retourner (par défaut 15)
    Returns
    -------
    longest_streamlines : list
        Liste des N streamlines les plus longues
    """ 
    # Calculer la longueur de chaque streamline avec shapely
    lengths = [LineString(s).length for s in streamlines]
    
    # Obtenir les indices des N plus longues
    longest_indices = np.argsort(lengths)[-number:]
    
    # Retourner les streamlines correspondantes
    longest_streamlines = [streamlines[i] for i in longest_indices]
    
    print(f"Streamlines les plus longues : {len(longest_streamlines)}")
    return longest_streamlines

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Process all bundles
for bundle_path in bundles:
    print(f"Processing {bundle_path} with threshold {threshold}")
    
    # Load tractogram
    tractogram = load_tractogram(bundle_path, reference='same',bbox_valid_check=False)
    streamlines = tractogram.streamlines

    streamlines = get_longest_streamlines(streamlines, number=int(0.1*len(streamlines)))
    
    # Apply QuickBundles clustering
    qb = QuickBundles(threshold=threshold)
    clusters = qb.cluster(streamlines)
    
    # Extract centroids
    centroids = [cluster.centroid for cluster in clusters]

    
    # Create a new tractogram with centroids
    centroids_tractogram = StatefulTractogram(centroids, tractogram, Space.RASMM)
    
    # Save centroids as .trk file
    bundle_name = bundle_path.split('/')[-1].replace('.trk', '')
    print(f"Bundle {bundle_name} - Number of clusters: {len(clusters)}")

    output_path = f"{output_dir}/{bundle_name}_centroids.trk"
    
    from dipy.io.streamline import save_tractogram
    save_tractogram(centroids_tractogram, output_path)
    
    print(f"Saved {len(centroids)} centroids to {output_path}")


In [None]:
import subprocess

ref = '/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/average_anat.nii.gz'
# Get all .trk files in the output directory
trk_files = glob(f"{output_dir}/*.trk")
out_flip_vtk = f"/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/long_central_line"

print(f"Found {len(trk_files)} .trk files in {output_dir}")

# Apply flip_tractogram command to each file
for trk_file in trk_files:
    out_vtk= f"{out_flip_vtk}/{os.path.basename(trk_file).replace('.trk', '.vtk')}"
    print(f'Processing {trk_file}...')

    # Run the flip_tractogram command
    result = subprocess.run(['flip_tractogram', trk_file, out_vtk,'--reference',ref], 
                            capture_output=True, text=True, check=True)
    print(f"Successfully flipped {trk_file}")


print("Finished processing all files")

## Multiple central lines

In [None]:
import numpy as np
from dipy.segment.clustering import QuickBundles
from dipy.io.streamline import load_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram,Space

output_dir="/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/centroids"
threshold=30.


import json

# Dictionary to store bundle names and their cluster counts
bundle_cluster_info = {}

# Process all bundles
for bundle_path in bundles:
    print(f"Processing {bundle_path} with threshold {threshold}")
    
    # Load tractogram
    tractogram = load_tractogram(bundle_path, reference='same',bbox_valid_check=False)
    streamlines = tractogram.streamlines
    
    # Apply QuickBundles clustering
    qb = QuickBundles(threshold=threshold)
    clusters = qb.cluster(streamlines)
    
    # Extract centroids
    centroids = [cluster.centroid for cluster in clusters]

    
    # Create a new tractogram with centroids
    centroids_tractogram = StatefulTractogram(centroids, tractogram, Space.RASMM)
    
    # Save centroids as .trk file
    bundle_name = bundle_path.split('/')[-1].replace('.trk', '')
    print(f"Bundle {bundle_name} - Number of clusters: {len(clusters)}")
    
    # Store bundle info
    bundle_cluster_info[bundle_name] = len(clusters)

    output_path = f"{output_dir}/{bundle_name}_centroids.vtk"
    
    save_tractogram(centroids_tractogram, output_path)
    
    print(f"Saved {len(centroids)} centroids to {output_path}")

# Save the bundle cluster information to JSON
json_output_path = f"{output_dir}/bundle_cluster_info.json"
with open(json_output_path, 'w') as json_file:
    json.dump(bundle_cluster_info, json_file, indent=2)

print(f"Saved bundle cluster information to {json_output_path}")

### Multiple clusters with Frechet distance

In [1]:
import os
import json
from glob import glob
import numpy as np
import vtk
from vtk.util.numpy_support import numpy_to_vtk
from dipy.segment.clustering import QuickBundles
from dipy.segment.featurespeed import Feature
from dipy.segment.metric import Metric
from dipy.tracking.streamline import Streamlines, set_number_of_points
from shapely.geometry import LineString
import shapely
from multiprocessing import Pool, cpu_count

def load_vtk_streamlines(vtk_file_path):
    reader = vtk.vtkPolyDataReader()
    reader.SetFileName(vtk_file_path)
    reader.Update()
    polydata = reader.GetOutput()
    lines = polydata.GetLines()
    streamlines = []
    lines.InitTraversal()
    id_list = vtk.vtkIdList()
    while lines.GetNextCell(id_list):
        line_points = []
        for j in range(id_list.GetNumberOfIds()):
            point_id = id_list.GetId(j)
            point = polydata.GetPoint(point_id)
            line_points.append(point)
        streamlines.append(np.array(line_points))
    return streamlines

# --- Réorientation par distance de Manhattan avec la streamline moyenne ---
def compute_mean_streamline(streamlines):
    qb = QuickBundles(threshold=1000)
    clusters = qb.cluster(streamlines)
    return clusters.centroids[0]

def reorient_to_reference_manhattan(streamlines, reference):
    if reference is None:
        return streamlines
    oriented = []
    reorient_index=[]
    for i,s in enumerate(streamlines):
        if len(s) == 0:
            oriented.append(s)
            continue
        d_orig = np.sum(np.abs(s - reference))
        s_rev = s[::-1]
        d_flip = np.sum(np.abs(s_rev - reference))
        oriented.append(s_rev if d_flip < d_orig else s)
        reorient_index.append(True if d_flip < d_orig else False)
    return Streamlines(oriented), reorient_index

def reorient_from_indices(streamlines, reorient_indices):
    reoriented = []
    for s, flip in zip(streamlines, reorient_indices):
        if flip:
            reoriented.append(s[::-1])
        else:
            reoriented.append(s)
    return Streamlines(reoriented)

class FrechetDistanceFeature(Feature):
    def infer_shape(self, datum):
        return np.asarray(datum).shape
    def extract(self, datum):
        return np.asarray(datum)

class FrechetDistanceMetric(Metric):
    def are_compatible(self, shape1, shape2):
        return True
    def distance(self, feature1, feature2):
        return shapely.hausdorff_distance(LineString(feature1), LineString(feature2))
    def dist(self, features1, features2):
        return self.distance(features1, features2)

output_dir = "/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/centroids_frechet_"
os.makedirs(output_dir, exist_ok=True)
threshold = 30.0

feature = FrechetDistanceFeature()
metric = FrechetDistanceMetric(feature=feature)

bundle_cluster_info = {}
bundles = glob('/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/vtk/summed_*.vtk')

def process_single_bundle(args):
    """Process a single bundle - designed for parallel execution"""
    bundle_path, output_dir, threshold, bundle_cluster_info_shared = args
    
    bundle_name = os.path.basename(bundle_path).replace('.vtk', '')
    print(f"\n---------------- Processing {bundle_name} --------------")
    
    sl_source = load_vtk_streamlines(bundle_path)
    # Rééchantillonnage
    sl = set_number_of_points(Streamlines(sl_source), 12)
    # Calcul de la streamline moyenne et réorientation Manhattan
    mean_ref = compute_mean_streamline(sl)
    sl ,reorient_index = reorient_to_reference_manhattan(sl, mean_ref)
    sl_source = reorient_from_indices(Streamlines(sl_source), reorient_index)
    # Statistiques longueur
    lengths = []
    for s in sl:
        if len(s) > 0:
            lengths.append(LineString(s).length)
    if len(lengths) > 0:
        mean_length = np.mean(lengths)
        print(f"Mean streamline length for {bundle_name}: {mean_length:.2f} mm")

    print(f"Processing {bundle_name} with threshold {threshold} (Frechet)")

    max_clusters = 3
    if bundle_name.replace('left', 'right') in bundle_cluster_info_shared:
        max_clusters = bundle_cluster_info_shared[bundle_name.replace('left', 'right')]
    elif bundle_name.replace('right', 'left') in bundle_cluster_info_shared:
        max_clusters = bundle_cluster_info_shared[bundle_name.replace('right', 'left')]

    print(f"Max clusters for {bundle_name}: {max_clusters}")

    qb = QuickBundles(threshold=threshold, metric=metric, max_nb_clusters=max_clusters)
    clusters = qb.cluster(sl)

    # Remove small clusters
    bad_clusters = clusters.get_small_clusters(1000)
    for bad_cluster in bad_clusters:
        print(f"Removing bad cluster with {len(bad_cluster.indices)} streamlines for {bundle_name}")
        clusters.remove_cluster(bad_cluster)

    centroids = clusters.centroids
    print(f"Bundle {bundle_name} - Number of clusters (Frechet): {len(clusters)}")

    # Ecriture des centroids en VTK
    centroid_polydata = vtk.vtkPolyData()
    centroid_points = vtk.vtkPoints()
    centroid_lines = vtk.vtkCellArray()
    centroid_polydata.SetPoints(centroid_points)
    centroid_indices = []
    for cid, c in enumerate(centroids):
        line = vtk.vtkPolyLine()
        line.GetPointIds().SetNumberOfIds(len(c))
        for i, p in enumerate(c):
            pid = centroid_points.InsertNextPoint(float(p[0]), float(p[1]), float(p[2]))
            line.GetPointIds().SetId(i, pid)
            centroid_indices.append(cid)
        centroid_lines.InsertNextCell(line)
    centroid_polydata.SetLines(centroid_lines)
    centroid_index_array = numpy_to_vtk(np.array(centroid_indices), deep=True)
    centroid_index_array.SetName('centroid_index')
    centroid_polydata.GetPointData().AddArray(centroid_index_array)
    centroid_writer = vtk.vtkPolyDataWriter()
    centroid_writer.SetFileName(os.path.join(output_dir, f"{bundle_name}_qbcentroids.vtk"))
    centroid_writer.SetInputData(centroid_polydata)
    centroid_writer.Write()

    # Ecriture du modèle avec centroid_index et point_index
    model_polydata = vtk.vtkPolyData()
    model_points = vtk.vtkPoints()
    model_lines = vtk.vtkCellArray()
    model_polydata.SetPoints(model_points)

    # Mapping streamline -> cluster id
    streamline_cluster_ids = np.full(len(sl), -1, dtype=int)
    for cid, c in enumerate(clusters):
        for sidx in c.indices:
            streamline_cluster_ids[sidx] = cid    
            
    model_centroid_indices = []
    model_point_indices = []
    for sidx, s in enumerate(sl_source):
        line = vtk.vtkPolyLine()
        line.GetPointIds().SetNumberOfIds(len(s))
        for i, p in enumerate(s):
            pid = model_points.InsertNextPoint(float(p[0]), float(p[1]), float(p[2]))
            line.GetPointIds().SetId(i, pid)
            model_centroid_indices.append(streamline_cluster_ids[sidx])
            model_point_indices.append(i)
        model_lines.InsertNextCell(line)
    model_polydata.SetLines(model_lines)
    arr_centroid_index = numpy_to_vtk(np.array(model_centroid_indices), deep=True)
    arr_centroid_index.SetName('centroid_index')
    model_polydata.GetPointData().AddArray(arr_centroid_index)
    arr_point_index = numpy_to_vtk(np.array(model_point_indices), deep=True)
    arr_point_index.SetName('point_index')
    model_polydata.GetPointData().AddArray(arr_point_index)
    model_writer = vtk.vtkPolyDataWriter()
    model_writer.SetFileName(os.path.join(output_dir, f"{bundle_name}_model_with_centroid_index.vtk"))
    model_writer.SetInputData(model_polydata)
    model_writer.Write()

    return bundle_name, len(clusters)

# Main execution
output_dir = "/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/centroids_frechet_parallel"
os.makedirs(output_dir, exist_ok=True)
threshold = 30.0

bundle_cluster_info = {}
bundles = glob('/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/vtk/summed_*.vtk')

# Prepare arguments for parallel processing
args_list = [(bundle_path, output_dir, threshold, bundle_cluster_info) for bundle_path in bundles]

# Use multiprocessing to process bundles in parallel
n_processes = min(cpu_count() - 1, len(bundles))  # Leave one CPU free
print(f"Processing {len(bundles)} bundles using {n_processes} processes")

with Pool(processes=n_processes) as pool:
    results = pool.map(process_single_bundle, args_list)

# Collect results
for bundle_name, n_clusters in results:
    bundle_cluster_info[bundle_name] = n_clusters

json_output_path = os.path.join(output_dir, 'bundle_cluster_info.json')
with open(json_output_path, 'w') as f:
    json.dump(bundle_cluster_info, f, indent=2)
print(f"\nSaved bundle cluster information to {json_output_path}")

Processing 71 bundles using 39 processes

---------------- Processing summed_ICP_left --------------
---------------- Processing summed_OR_right --------------
---------------- Processing summed_ILF_right --------------
---------------- Processing summed_ILF_left --------------
---------------- Processing summed_SCP_left --------------
---------------- Processing summed_CST_right --------------
---------------- Processing summed_CST_left --------------
---------------- Processing summed_CC_1 --------------
---------------- Processing summed_OR_left --------------
---------------- Processing summed_SCP_right --------------
---------------- Processing summed_FPT_right --------------
---------------- Processing summed_FPT_left --------------


---------------- Processing summed_CC_3 --------------
---------------- Processing summed_IFO_left --------------

---------------- Processing summed_POPT_right --------------

---------------- Processing summed_POPT_left --------------
------------

In [1]:
import os
import json
from glob import glob
import numpy as np
import vtk
from vtk.util.numpy_support import vtk_to_numpy, numpy_to_vtk
from dipy.segment.clustering import QuickBundles
from dipy.segment.metric import AveragePointwiseEuclideanMetric
from dipy.tracking.streamline import Streamlines
from dipy.tracking.streamline import set_number_of_points

def load_vtk_with_indices(vtk_file_path):
    """Load VTK file and extract streamlines with their centroid_index"""
    reader = vtk.vtkPolyDataReader()
    reader.SetFileName(vtk_file_path)
    reader.Update()
    polydata = reader.GetOutput()
    
    # Get centroid_index array
    centroid_index_array = polydata.GetPointData().GetArray('centroid_index')
    if centroid_index_array is None:
        raise ValueError(f"No centroid_index found in {vtk_file_path}")
    
    centroid_indices = vtk_to_numpy(centroid_index_array)
    
    # Extract streamlines
    lines = polydata.GetLines()
    streamlines = []
    streamline_cluster_ids = []
    
    lines.InitTraversal()
    id_list = vtk.vtkIdList()
    point_idx = 0
    
    while lines.GetNextCell(id_list):
        line_points = []
        cluster_id = None
        
        for j in range(id_list.GetNumberOfIds()):
            point_id = id_list.GetId(j)
            point = polydata.GetPoint(point_id)
            line_points.append(point)
            
            if j == 0:
                cluster_id = centroid_indices[point_idx]
            point_idx += 1
        
        streamlines.append(np.array(line_points))
        streamline_cluster_ids.append(cluster_id)
    
    return streamlines, streamline_cluster_ids

def get_longest_streamlines(streamlines, percentile=85):
    """Get the top (100-percentile)% longest streamlines"""
    lengths = [len(s) for s in streamlines]
    threshold = np.percentile(lengths, percentile)
    return [s for s in streamlines if len(s) >= threshold]

def compute_centroid_from_cluster(cluster_streamlines, n_points=50):
    """
    Compute centroid using QuickBundles on the longest streamlines
    """
    # Keep only the 15% longest streamlines
    long_streamlines = get_longest_streamlines(cluster_streamlines, percentile=85)
    
    if len(long_streamlines) == 0:
        long_streamlines = cluster_streamlines
    
    print(f"      Using {len(long_streamlines)} longest streamlines (from {len(cluster_streamlines)} total)")
    
    # Resample all streamlines to same number of points
    resampled = set_number_of_points(Streamlines(long_streamlines), n_points)
    
    # Use QuickBundles with a single cluster to get the centroid
    metric = AveragePointwiseEuclideanMetric()
    qb = QuickBundles(threshold=np.inf, metric=metric)
    clusters = qb.cluster(resampled)
    
    # Get the centroid from QuickBundles
    centroid = clusters.centroids[0]
    
    return centroid

# Directories
input_dir = "/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/centroids_frechetlong2"
output_dir = "/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/centroids_frechetlong2"
os.makedirs(output_dir, exist_ok=True)

# Get all model files
model_files = glob(os.path.join(input_dir, '*_model_with_centroid_index.vtk'))
print(f"Found {len(model_files)} model files to process")

bundle_cluster_info = {}

for model_path in model_files:
    bundle_name = os.path.basename(model_path).replace('_model_with_centroid_index.vtk', '')
    print(f"\n{'='*60}")
    print(f"Processing {bundle_name}")
    print(f"{'='*60}")
    
    # Load streamlines with cluster assignments
    streamlines, cluster_ids = load_vtk_with_indices(model_path)
    print(f"Loaded {len(streamlines)} streamlines")
    
    # Group streamlines by cluster
    unique_clusters = sorted(set([c for c in cluster_ids if c >= 0]))
    print(f"Found {len(unique_clusters)} clusters")
    
    optimized_centroids = []
    
    for cluster_id in unique_clusters:
        # Get streamlines for this cluster
        cluster_streamlines = [streamlines[i] for i, cid in enumerate(cluster_ids) if cid == cluster_id]
        
        if len(cluster_streamlines) == 0:
            continue
        
        print(f"\n  Cluster {cluster_id}: {len(cluster_streamlines)} streamlines")
        
        # Compute centroid using QuickBundles
        print(f"    Computing centroid with QuickBundles...")
        centroid = compute_centroid_from_cluster(cluster_streamlines, n_points=500)
        optimized_centroids.append(centroid)
        
        print(f"    Centroid computed with {len(centroid)} points")
    
    bundle_cluster_info[bundle_name] = len(optimized_centroids)
    print(f"\nGenerated {len(optimized_centroids)} centroids for {bundle_name}")
    
    # Write centroids to VTK
    centroid_polydata = vtk.vtkPolyData()
    centroid_points = vtk.vtkPoints()
    centroid_lines = vtk.vtkCellArray()
    centroid_polydata.SetPoints(centroid_points)
    centroid_indices = []
    
    for cid, centroid in enumerate(optimized_centroids):
        line = vtk.vtkPolyLine()
        line.GetPointIds().SetNumberOfIds(len(centroid))
        for i, p in enumerate(centroid):
            pid = centroid_points.InsertNextPoint(float(p[0]), float(p[1]), float(p[2]))
            line.GetPointIds().SetId(i, pid)
            centroid_indices.append(cid)
        centroid_lines.InsertNextCell(line)
    
    centroid_polydata.SetLines(centroid_lines)
    centroid_index_array = numpy_to_vtk(np.array(centroid_indices), deep=True)
    centroid_index_array.SetName('centroid_index')
    centroid_polydata.GetPointData().AddArray(centroid_index_array)
    
    centroid_writer = vtk.vtkPolyDataWriter()
    centroid_writer.SetFileName(os.path.join(output_dir, f"{bundle_name}_centroids.vtk"))
    centroid_writer.SetInputData(centroid_polydata)
    centroid_writer.Write()
    
    print(f"Saved centroids to {output_dir}/{bundle_name}_centroids.vtk")

# Save bundle info
json_output_path = os.path.join(output_dir, 'bundle_cluster_info.json')
with open(json_output_path, 'w') as f:
    json.dump(bundle_cluster_info, f, indent=2)
print(f"\n{'='*60}")
print(f"Saved bundle cluster information to {json_output_path}")
print(f"{'='*60}")


Found 71 model files to process

Processing summed_CST_left
Loaded 32207 streamlines
Found 2 clusters

  Cluster 0: 3557 streamlines
    Computing centroid with QuickBundles...
      Using 653 longest streamlines (from 3557 total)
    Centroid computed with 500 points

  Cluster 1: 28650 streamlines
    Computing centroid with QuickBundles...
      Using 6105 longest streamlines (from 28650 total)
    Centroid computed with 500 points

Generated 2 centroids for summed_CST_left
Saved centroids to /home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/centroids_frechetlong2/summed_CST_left_centroids.vtk

Processing summed_OR_right
Loaded 23976 streamlines
Found 2 clusters

  Cluster 0: 21779 streamlines
    Computing centroid with QuickBundles...
      Using 3365 longest streamlines (from 21779 total)
    Centroid computed with 500 points

  Cluster 1: 2197 streamlines
    Computing centroid with QuickBundles...
      Using 340 longest streamlines