In [None]:
import os
import vtk
import numpy as np
from glob import glob
from dipy.tracking.streamline import orient_by_streamline
# Bundle files: every *.vtk in the atlas vtk directory whose path contains 'CST_left'.
vtk_files = [
    path
    for path in glob('/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/vtk/*.vtk')
    if 'CST_left' in path
]

# Reference centerline files, selected with the same substring test.
ref_centerlines = [
    path
    for path in glob('/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/long_central_line/*.vtk')
    if 'CST_left' in path
]



def load_vtk_streamlines(vtk_file_path):
    """
    Read the line cells of a legacy VTK polydata file as numpy arrays.

    The file is read with ``vtkPolyDataReader``; every cell of the polydata's
    line array becomes one array whose rows are the coordinates returned by
    ``GetPoint`` for the cell's point ids, in cell order.

    Parameters:
    -----------
    vtk_file_path : str
        Path handed to the reader's ``SetFileName``.

    Returns:
    --------
    streamlines : list of numpy arrays
        One array per line cell (normally of shape (n_points, 3)).
    """
    reader = vtk.vtkPolyDataReader()
    reader.SetFileName(vtk_file_path)
    reader.Update()
    polydata = reader.GetOutput()

    cell_array = polydata.GetLines()
    cell_array.InitTraversal()
    # GetNextCell refills this id list on every call and returns a falsy
    # value once all cells have been visited.
    cell_point_ids = vtk.vtkIdList()

    streamlines = []
    while cell_array.GetNextCell(cell_point_ids):
        coords = [
            polydata.GetPoint(cell_point_ids.GetId(k))
            for k in range(cell_point_ids.GetNumberOfIds())
        ]
        streamlines.append(np.array(coords))
    return streamlines


# Load the first matching bundle and show the shape of every streamline.
sls = load_vtk_streamlines(vtk_files[0])
print([line.shape for line in sls])

# The first line of the first reference file is the orientation standard
# passed to dipy's orient_by_streamline together with the bundle.
reference_line = load_vtk_streamlines(ref_centerlines[0])[0]
sls_oriented = orient_by_streamline(sls, reference_line)
print([line.shape for line in sls_oriented])


def get_normalized_index_streamlines(streamlines, n_bins=10):
    """
    Label every point of every streamline with the bin of its relative position.

    A streamline with N points is given the positions ``np.linspace(0, 1, N)``
    (first point 0, last point 1). The interval [0, 1] is cut into ``n_bins``
    equal bins and each point receives the integer label of its bin, from 0 to
    ``n_bins - 1``. Bins are closed on the right; position 0 belongs to bin 0.
    Only the number of points of each streamline is used, not its coordinates.
    Streamlines are assumed to be oriented consistently.

    Parameters:
    -----------
    streamlines : list of numpy arrays
        List of streamlines, each with shape (N, 3)
    n_bins : int
        Number of bins along the streamline; must be at least 1

    Returns:
    --------
    normalized_streamlines : list of numpy arrays
        One integer array of length N per input streamline, values in
        0 .. n_bins - 1

    Raises:
    -------
    ValueError
        If ``n_bins`` is smaller than 1
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")

    bin_edges = np.linspace(0, 1, n_bins + 1)

    normalized_streamlines = []
    for sl in streamlines:
        positions = np.linspace(0, 1, sl.shape[0])
        # digitize(right=True) returns 0 only for position 0 and k for
        # positions in (edge[k-1], edge[k]], i.e. n_bins + 1 labels with a
        # degenerate first one. Shift by one and clamp so the first point
        # joins the first bin and labels run 0 .. n_bins - 1.
        raw_labels = np.digitize(positions, bin_edges, right=True)
        normalized_streamlines.append(np.maximum(raw_labels - 1, 0))
    return normalized_streamlines

sls_normalized = get_normalized_index_streamlines(sls_oriented)

# Save sls_oriented as a single VTK file, with an additional point-data array
# holding the per-point values returned by get_normalized_index_streamlines.
output_dir = '/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/vtk_normalized_indices/'
os.makedirs(output_dir, exist_ok=True)

# All streamlines share one vtkPoints container; each polyline refers to its
# own points through a running offset into that container.
points = vtk.vtkPoints()
lines = vtk.vtkCellArray()
normalized_indices = vtk.vtkFloatArray()
normalized_indices.SetName('NormalizedIndex')

point_offset = 0
for sl, sl_norm in zip(sls_oriented, sls_normalized):
    n_points = len(sl)

    # Add points
    for point in sl:
        points.InsertNextPoint(point)

    # Create line connectivity
    poly_line = vtk.vtkPolyLine()
    poly_line.GetPointIds().SetNumberOfIds(n_points)
    for j in range(n_points):
        poly_line.GetPointIds().SetId(j, point_offset + j)
    lines.InsertNextCell(poly_line)

    # One value per point, inserted in the same order as the points
    for idx in sl_norm:
        normalized_indices.InsertNextValue(idx)

    point_offset += n_points

# Create single vtkPolyData
poly_data = vtk.vtkPolyData()
poly_data.SetPoints(points)
poly_data.SetLines(lines)
poly_data.GetPointData().AddArray(normalized_indices)

# Write to single file
writer = vtk.vtkPolyDataWriter()
# Parse the bundle name from the file name only, so that a '.' or 'summed_'
# occurring in a directory name cannot corrupt the result.
bundle_name = os.path.basename(vtk_files[0]).split('.')[0].split('summed_')[1]
normalized_vtk_path = os.path.join(output_dir, f'normalized_{bundle_name}.vtk')
writer.SetFileName(normalized_vtk_path)
writer.SetInputData(poly_data)
# Write() reports failure through its return value (0); do not claim success then.
if not writer.Write():
    raise OSError(f'Failed to write {normalized_vtk_path}')

print(f'Saved {len(sls_oriented)} oriented streamlines with normalized indices to {output_dir}')

[(25, 3), (21, 3), (23, 3), (29, 3), (28, 3), (23, 3), (29, 3), (24, 3), (21, 3), (23, 3), (23, 3), (23, 3), (29, 3), (26, 3), (26, 3), (22, 3), (23, 3), (24, 3), (24, 3), (28, 3), (22, 3), (23, 3), (23, 3), (27, 3), (27, 3), (21, 3), (26, 3), (20, 3), (30, 3), (24, 3), (21, 3), (29, 3), (25, 3), (22, 3), (28, 3), (24, 3), (22, 3), (24, 3), (22, 3), (28, 3), (21, 3), (19, 3), (31, 3), (18, 3), (26, 3), (24, 3), (22, 3), (23, 3), (25, 3), (21, 3), (24, 3), (30, 3), (21, 3), (25, 3), (26, 3), (22, 3), (31, 3), (20, 3), (21, 3), (26, 3), (27, 3), (21, 3), (31, 3), (22, 3), (25, 3), (22, 3), (27, 3), (23, 3), (20, 3), (24, 3), (24, 3), (27, 3), (26, 3), (28, 3), (19, 3), (29, 3), (28, 3), (29, 3), (26, 3), (17, 3), (25, 3), (18, 3), (23, 3), (20, 3), (23, 3), (20, 3), (26, 3), (22, 3), (25, 3), (19, 3), (22, 3), (21, 3), (29, 3), (26, 3), (22, 3), (26, 3), (22, 3), (24, 3), (20, 3), (25, 3), (24, 3), (29, 3), (19, 3), (24, 3), (19, 3), (27, 3), (22, 3), (29, 3), (26, 3), (27, 3), (21, 3), 

In [None]:
import os
import vtk
import numpy as np
from glob import glob
from dipy.tracking.streamline import orient_by_streamline
# Bundle files: every *.vtk in the atlas vtk directory whose path contains 'CST_left'.
vtk_files = [
    path
    for path in glob('/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/vtk/*.vtk')
    if 'CST_left' in path
]

# Reference centerline files, selected with the same substring test.
ref_centerlines = [
    path
    for path in glob('/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/long_central_line/*.vtk')
    if 'CST_left' in path
]

def load_vtk_streamlines(vtk_file_path):
    """
    Read the line cells of a legacy VTK polydata file as numpy arrays.

    The file is read with ``vtkPolyDataReader``; every cell of the polydata's
    line array becomes one array whose rows are the coordinates returned by
    ``GetPoint`` for the cell's point ids, in cell order.

    Parameters:
    -----------
    vtk_file_path : str
        Path handed to the reader's ``SetFileName``.

    Returns:
    --------
    streamlines : list of numpy arrays
        One array per line cell (normally of shape (n_points, 3)).
    """
    reader = vtk.vtkPolyDataReader()
    reader.SetFileName(vtk_file_path)
    reader.Update()
    polydata = reader.GetOutput()

    cell_array = polydata.GetLines()
    cell_array.InitTraversal()
    # GetNextCell refills this id list on every call and returns a falsy
    # value once all cells have been visited.
    cell_point_ids = vtk.vtkIdList()

    streamlines = []
    while cell_array.GetNextCell(cell_point_ids):
        coords = [
            polydata.GetPoint(cell_point_ids.GetId(k))
            for k in range(cell_point_ids.GetNumberOfIds())
        ]
        streamlines.append(np.array(coords))
    return streamlines


# Load the first matching bundle and show the shape of every streamline.
sls = load_vtk_streamlines(vtk_files[0])
print([line.shape for line in sls])

# The first line of the first reference file is the orientation standard
# passed to dipy's orient_by_streamline together with the bundle.
reference_line = load_vtk_streamlines(ref_centerlines[0])[0]
sls_oriented = orient_by_streamline(sls, reference_line)
print([line.shape for line in sls_oriented])


def get_normalized_index_streamlines(streamlines, n_bins=10):
    """
    Label every point of every streamline with the bin of its relative position.

    A streamline with N points is given the positions ``np.linspace(0, 1, N)``
    (first point 0, last point 1). The interval [0, 1] is cut into ``n_bins``
    equal bins and each point receives the integer label of its bin, from 0 to
    ``n_bins - 1``. Bins are closed on the right; position 0 belongs to bin 0.
    Only the number of points of each streamline is used, not its coordinates.
    Streamlines are assumed to be oriented consistently.

    Parameters:
    -----------
    streamlines : list of numpy arrays
        List of streamlines, each with shape (N, 3)
    n_bins : int
        Number of bins along the streamline; must be at least 1

    Returns:
    --------
    normalized_streamlines : list of numpy arrays
        One integer array of length N per input streamline, values in
        0 .. n_bins - 1

    Raises:
    -------
    ValueError
        If ``n_bins`` is smaller than 1
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")

    bin_edges = np.linspace(0, 1, n_bins + 1)

    normalized_streamlines = []
    for sl in streamlines:
        positions = np.linspace(0, 1, sl.shape[0])
        # digitize(right=True) returns 0 only for position 0 and k for
        # positions in (edge[k-1], edge[k]], i.e. n_bins + 1 labels with a
        # degenerate first one. Shift by one and clamp so the first point
        # joins the first bin and labels run 0 .. n_bins - 1, matching the
        # range(n_bins) loop used when training one classifier per bin.
        raw_labels = np.digitize(positions, bin_edges, right=True)
        normalized_streamlines.append(np.maximum(raw_labels - 1, 0))
    return normalized_streamlines

sls_normalized = get_normalized_index_streamlines(sls_oriented)

# Save sls_oriented as a single VTK file, with an additional point-data array
# holding the per-point values returned by get_normalized_index_streamlines.
output_dir = '/home/ndecaux/NAS_EMPENN/share/projects/HCP105_Zenodo_NewTrkFormat/inGroupe1Space/Atlas/flipped/vtk_normalized_indices/'
os.makedirs(output_dir, exist_ok=True)

# All streamlines share one vtkPoints container; each polyline refers to its
# own points through a running offset into that container.
points = vtk.vtkPoints()
lines = vtk.vtkCellArray()
normalized_indices = vtk.vtkFloatArray()
normalized_indices.SetName('NormalizedIndex')

point_offset = 0
for sl, sl_norm in zip(sls_oriented, sls_normalized):
    n_points = len(sl)

    # Add points
    for point in sl:
        points.InsertNextPoint(point)

    # Create line connectivity
    poly_line = vtk.vtkPolyLine()
    poly_line.GetPointIds().SetNumberOfIds(n_points)
    for j in range(n_points):
        poly_line.GetPointIds().SetId(j, point_offset + j)
    lines.InsertNextCell(poly_line)

    # One value per point, inserted in the same order as the points
    for idx in sl_norm:
        normalized_indices.InsertNextValue(idx)

    point_offset += n_points

# Create single vtkPolyData
poly_data = vtk.vtkPolyData()
poly_data.SetPoints(points)
poly_data.SetLines(lines)
poly_data.GetPointData().AddArray(normalized_indices)

# Write to single file
writer = vtk.vtkPolyDataWriter()
# Parse the bundle name from the file name only, so that a '.' or 'summed_'
# occurring in a directory name cannot corrupt the result.
bundle_name = os.path.basename(vtk_files[0]).split('.')[0].split('summed_')[1]
normalized_vtk_path = os.path.join(output_dir, f'normalized_{bundle_name}.vtk')
writer.SetFileName(normalized_vtk_path)
writer.SetInputData(poly_data)
# Write() reports failure through its return value (0); do not claim success then.
if not writer.Write():
    raise OSError(f'Failed to write {normalized_vtk_path}')

print(f'Saved {len(sls_oriented)} oriented streamlines with normalized indices to {output_dir}')

from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from vtk.util.numpy_support import numpy_to_vtk
import numpy as np
import vtk
import os
from skimage import measure
def train_svm_classifiers(streamlines, normalized_indices, n_bins=10):
    """
    Train one binary SVM per bin to separate that bin's points from all others.

    The points of all streamlines are pooled. For every bin label ``i`` in
    ``range(n_bins)`` an RBF-kernel ``SVC`` is fitted on the standardized point
    coordinates, with target 1 for points labelled ``i`` and 0 for the rest.

    Parameters:
    -----------
    streamlines : list of numpy arrays
        List of streamlines, each with shape (N, 3)
    normalized_indices : list of numpy arrays
        Integer bin label of every streamline point, aligned with
        ``streamlines``. Only labels 0 .. n_bins - 1 get a classifier; points
        carrying any other label serve as negatives only (a warning is printed).
    n_bins : int
        Number of bins, i.e. number of entries in the returned lists

    Returns:
    --------
    classifiers : list
        ``n_bins`` entries; entry ``i`` is the fitted SVC for bin ``i``, or
        None when no point carries that label
    scalers : list
        ``n_bins`` entries aligned with ``classifiers``. Every non-None entry
        is the same StandardScaler instance, fitted once on all pooled points
        (the fit does not depend on the bin).
    bin_edges : numpy array
        ``np.linspace(0, 1, n_bins + 1)``, returned for reference; it is not
        used for training
    """

    # Create bin edges for reference
    bin_edges = np.linspace(0, 1, n_bins + 1)
    classifiers = []
    scalers = []

    # Collect all points with their bin indices
    all_points = []
    all_bin_indices = []
    for sl, bin_idx in zip(streamlines, normalized_indices):
        all_points.extend(sl)
        all_bin_indices.extend(bin_idx)
    all_points = np.array(all_points)
    all_bin_indices = np.array(all_bin_indices)

    # Labels outside 0 .. n_bins - 1 never become a positive class below;
    # say so instead of dropping them silently.
    n_out_of_range = int(np.sum((all_bin_indices < 0) | (all_bin_indices >= n_bins)))
    if n_out_of_range:
        print(f"Warning: {n_out_of_range} point(s) have a bin label outside "
              f"0..{n_bins - 1}; no classifier is trained for those labels")

    # The scaler is fitted lazily (first non-empty bin) so that input without
    # any usable label never reaches StandardScaler, and only once because
    # the features are identical for every bin.
    shared_scaler = None
    X_scaled = None

    # Train one SVM classifier for each bin
    for i in range(n_bins):
        # Create binary labels: 1 if point is in current bin, 0 otherwise
        labels = (all_bin_indices == i).astype(int)
        n_positive = np.sum(labels)

        # Skip if bin is empty
        if n_positive == 0:
            classifiers.append(None)
            scalers.append(None)
            continue

        # Standardize features
        if shared_scaler is None:
            shared_scaler = StandardScaler()
            X_scaled = shared_scaler.fit_transform(all_points)

        # Train SVM classifier
        svm = SVC(kernel='rbf', C=1.0, gamma='scale', class_weight='balanced')
        svm.fit(X_scaled, labels)

        classifiers.append(svm)
        scalers.append(shared_scaler)

        print(f"Trained SVM for bin {i}/{n_bins-1}, positive samples: {n_positive}")

    return classifiers, scalers, bin_edges

# Fit the per-bin classifiers on the first 1000 oriented streamlines.
classifiers, scalers, bin_edges = train_svm_classifiers(
    sls_oriented[:1000], sls_normalized[:1000], n_bins=10
)
print('classif done')

# Stack the points of every streamline into one (num_points, 3) array.
all_points = np.vstack(sls_oriented)
num_points = all_points.shape[0]

# One column of decision values per classifier. Columns of missing (None)
# classifiers stay at -inf so they can never win the argmax below.
decision_matrix = np.full((num_points, len(classifiers)), -np.inf)
for i, (svm, scaler) in enumerate(zip(classifiers, scalers)):
    if svm is None:
        continue
    points_scaled = scaler.transform(all_points)
    decision_matrix[:, i] = svm.decision_function(points_scaled)

# Each point gets the index of the classifier with the largest decision value.
labels = np.argmax(decision_matrix, axis=1)

# Save the streamlines again, this time with the predicted label of every
# point stored as the 'SVM_Label' point-data array.
output_vtk_path = os.path.join(output_dir, f'svm_labels_{bundle_name}.vtk')

points = vtk.vtkPoints()
points.SetData(numpy_to_vtk(all_points, deep=True))

# all_points is the concatenation of sls_oriented, so the points of each
# streamline occupy a contiguous id range starting at point_offset.
lines = vtk.vtkCellArray()
point_offset = 0
for sl in sls_oriented:
    n_points = len(sl)
    line_point_ids = vtk.vtkIdList()
    for j in range(n_points):
        line_point_ids.InsertNextId(point_offset + j)
    lines.InsertNextCell(line_point_ids)
    point_offset += n_points

label_array = numpy_to_vtk(labels.astype(np.int32), deep=True)
label_array.SetName('SVM_Label')

poly_data = vtk.vtkPolyData()
poly_data.SetPoints(points)
poly_data.SetLines(lines)
poly_data.GetPointData().AddArray(label_array)

writer = vtk.vtkPolyDataWriter()
writer.SetFileName(output_vtk_path)
writer.SetInputData(poly_data)
# Write() reports failure through its return value (0); do not claim success then.
if not writer.Write():
    raise OSError(f"Failed to write {output_vtk_path}")
print(f"Saved streamline point labels to {output_vtk_path}")

# Export the decision boundary of every SVM as a triangle mesh.
# For each SVM, the decision function is sampled on a regular 3D grid around
# the data and the zero-level set of that volume is extracted.

grid_size = 50
margin = 5
mins = all_points.min(axis=0) - margin
maxs = all_points.max(axis=0) + margin
X, Y, Z = np.meshgrid(
    np.linspace(mins[0], maxs[0], grid_size),
    np.linspace(mins[1], maxs[1], grid_size),
    np.linspace(mins[2], maxs[2], grid_size),
    indexing='ij'
)
grid_points = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T

hyperplane_polydata_list = []
for i, (svm, scaler) in enumerate(zip(classifiers, scalers)):
    if svm is None:
        continue
    grid_scaled = scaler.transform(grid_points)
    decision_values = svm.decision_function(grid_scaled)
    volume = decision_values.reshape(X.shape)
    # marching_cubes raises ValueError when the requested level lies outside
    # the range of the volume; skip such a classifier instead of aborting
    # the export of the remaining ones.
    if volume.min() > 0 or volume.max() < 0:
        print(f"No decision boundary for bin {i} inside the sampled grid, skipping")
        continue
    # Extract the zero-level set (decision boundary)
    verts, faces, _, _ = measure.marching_cubes(volume, level=0)
    # Transform verts from grid-index coordinates back to world coordinates
    # (index 0 maps to mins, index grid_size - 1 maps to maxs), keeping the
    # dtype marching_cubes returned.
    verts_world = (mins + (maxs - mins) * verts / (grid_size - 1)).astype(verts.dtype)
    # Create VTK PolyData
    vtk_points = vtk.vtkPoints()
    vtk_points.SetData(numpy_to_vtk(verts_world, deep=True))
    vtk_faces = vtk.vtkCellArray()
    for face in faces:
        vtk_faces.InsertNextCell(3)
        for idx in face:
            vtk_faces.InsertCellPoint(int(idx))
    poly = vtk.vtkPolyData()
    poly.SetPoints(vtk_points)
    poly.SetPolys(vtk_faces)
    hyperplane_polydata_list.append(poly)
    # Save each boundary mesh
    hyperplane_path = os.path.join(output_dir, f'svm_hyperplane_{bundle_name}_bin{i}.vtk')
    writer = vtk.vtkPolyDataWriter()
    writer.SetFileName(hyperplane_path)
    writer.SetInputData(poly)
    # Write() reports failure through its return value (0); do not claim success then.
    if not writer.Write():
        raise OSError(f"Failed to write {hyperplane_path}")
    print(f"Saved SVM hyperplane for bin {i} to {hyperplane_path}")

Trained SVM for bin 0/9, positive samples: 1000
Trained SVM for bin 1/9, positive samples: 1907
Trained SVM for bin 2/9, positive samples: 2269
Trained SVM for bin 3/9, positive samples: 2216
Trained SVM for bin 4/9, positive samples: 2366
Trained SVM for bin 5/9, positive samples: 2415
Trained SVM for bin 6/9, positive samples: 2141
Trained SVM for bin 7/9, positive samples: 2260
Trained SVM for bin 8/9, positive samples: 2322
Trained SVM for bin 9/9, positive samples: 2163
classif done
