In [None]:
import os

import numpy as np
import open3d as o3d


In [None]:
def parse_electrode_file(filepath, label_mapping=None, default_label=1):
    """
    Read an electrode coordinate file and assign an integer label code to each row.

    Each usable line must have exactly four whitespace-separated fields:
    ``<name> <x> <y> <z>``. Lines with any other number of fields (blank
    lines, short headers, ...) are skipped.

    Args:
        filepath: Path of the text file to read.
        label_mapping: Optional dict mapping a point name to its label code.
            Defaults to {"lhj": 2, "rhj": -2, "nas": 3}.
        default_label: Label given to names not found in ``label_mapping``
            (regular electrodes). Defaults to 1.

    Returns:
        coords: float array of shape (M, 3), one row per parsed line
            (shape (0, 3) when no line was parsed).
        labels: int array of shape (M,). Values come from ``label_mapping``
            or ``default_label`` (1, 2, -2, 3 with the defaults). Label 0 is
            never produced here; it is the background value for points that
            are not near any coordinate.

    Raises:
        ValueError: if a four-field line contains a non-numeric coordinate;
            the message includes the file name and 1-based line number.
    """
    if label_mapping is None:
        label_mapping = {"lhj": 2, "rhj": -2, "nas": 3}

    coords = []
    labels = []

    with open(filepath, 'r') as f:
        for lineno, line in enumerate(f, start=1):
            parts = line.split()
            if len(parts) != 4:
                continue
            name = parts[0]
            try:
                xyz = [float(value) for value in parts[1:]]
            except ValueError as exc:
                raise ValueError(
                    f"{filepath}:{lineno}: non-numeric coordinate in {line.strip()!r}"
                ) from exc
            coords.append(xyz)
            labels.append(label_mapping.get(name, default_label))

    # reshape keeps the (M, 3) contract even when nothing was parsed.
    return np.array(coords, dtype=float).reshape(-1, 3), np.array(labels, dtype=int)



def label_point_cloud(pcd_path, electrode_txt, radius=0.006, output_ply="labeled_output.ply", save_npz=True):
    """
    Label every point of a point cloud by its proximity to electrode coordinates.

    A point receives the label of an electrode coordinate when its Euclidean
    distance to that coordinate is <= ``radius``. All other points keep label
    0 ("no electrode"). When a point is within ``radius`` of several
    coordinates, the coordinate listed LAST in ``electrode_txt`` wins, because
    later assignments overwrite earlier ones.

    Side effects:
        - writes a colored copy of the cloud to ``output_ply``;
        - if ``save_npz``, writes ``points`` and ``labels`` arrays to an .npz
          archive at the same path with the extension replaced by ``.npz``.

    Args:
        pcd_path: Path of the input point cloud, read with open3d.
        electrode_txt: Electrode coordinate file, parsed by ``parse_electrode_file``.
        radius: Labeling radius, in the same units as the point coordinates.
        output_ply: Path of the colored output point cloud.
        save_npz: Whether to also save the ``points``/``labels`` arrays.

    Returns:
        (points, labels): float array (N, 3) of the input coordinates and int32
        array (N,) of label codes (0 for background, otherwise the code
        assigned by ``parse_electrode_file``).

    Raises:
        ValueError: if the input cloud has no points. NOTE: open3d typically
            logs a warning and returns an empty cloud for a missing/unreadable
            file instead of raising, so this is the only signal of a bad path.
        IOError: if open3d reports failure writing ``output_ply``.
    """
    coords, coord_labels = parse_electrode_file(electrode_txt)

    pcd = o3d.io.read_point_cloud(pcd_path)
    points = np.asarray(pcd.points)
    if len(points) == 0:
        raise ValueError(f"Point cloud {pcd_path!r} is empty or could not be read")
    labels = np.zeros(len(points), dtype=np.int32)

    # Assign labels based on proximity to electrode coords (later coords overwrite earlier ones).
    for coord, label_value in zip(coords, coord_labels):
        dists = np.linalg.norm(points - coord, axis=1)
        labels[dists <= radius] = label_value

    # Assign colors for visualization
    label_colors = {
        0: [0.1, 0.1, 0.1],   # Dark gray: no electrode
        1: [1.0, 0.0, 0.0],   # Red: regular electrode
        2: [0.0, 1.0, 0.0],   # Green: "lhj"
        -2: [0.0, 0.0, 1.0],  # Blue: "rhj"
        3: [1.0, 1.0, 0.0],   # Yellow: "nas"
    }
    colors = np.zeros_like(points)
    for label_value, rgb in label_colors.items():
        colors[labels == label_value] = rgb

    # Create labeled point cloud
    labeled_pcd = o3d.geometry.PointCloud()
    labeled_pcd.points = o3d.utility.Vector3dVector(points)
    labeled_pcd.colors = o3d.utility.Vector3dVector(colors)
    # write_point_cloud reports failure via its return value, not an exception.
    if not o3d.io.write_point_cloud(output_ply, labeled_pcd):
        raise IOError(f"Failed to write labeled point cloud to {output_ply!r}")
    print(f" Labeled .ply saved: {output_ply}")

    if save_npz:
        # splitext only swaps the final extension; str.replace('.ply', ...) would
        # also rewrite '.ply' inside directory names and miss '.PLY'.
        npz_path = os.path.splitext(output_ply)[0] + '.npz'
        np.savez(npz_path, points=points, labels=labels)
        print(f" .npz file saved for training: {npz_path}")

    return points, labels


In [None]:
labeled_points, point_labels = label_point_cloud(
    "path_file.ply",
    "path_file.txt",
    radius=0.006,
    output_ply="labeled_output.ply",
    save_npz=True,
)
# Summarize points per label code instead of echoing the full (N, 3) and (N,) arrays.
label_values, label_counts = np.unique(point_labels, return_counts=True)
dict(zip(label_values.tolist(), label_counts.tolist()))

 Labeled .ply saved: labeled_output.ply
 .npz file saved for training


(array([[-0.03696871,  0.19248752,  0.1387763 ],
        [-0.03675695,  0.19197108,  0.13898242],
        [-0.03596369,  0.19212349,  0.13924372],
        ...,
        [ 0.05704075,  0.09028162, -0.00309354],
        [ 0.05670792,  0.08974637, -0.00318378],
        [ 0.05684155,  0.09048887, -0.00364035]]),
 array([0, 0, 0, ..., 0, 0, 0], dtype=int32))

In [None]:
# Visualize the labeled 3D point cloud, colored by label code.

# Load the training-ready .npz file; the context manager closes the archive
# handle (arrays read from it stay valid after the block).
with np.load("path_file.npz") as data:
    points = data["points"]
    labels = data["labels"]

# Create point cloud
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)

# Define colors per label
label_colors = {
    0: [0.3, 0.3, 0.3],   # No electrode
    1: [1.0, 0.0, 0.0],   # Regular electrode
    2: [0.0, 1.0, 0.0],   # Left tragus (LPA)
    -2: [0.0, 0.0, 1.0],  # Right tragus (RPA)
    3: [1.0, 1.0, 0.0],   # Nasion (Nz)
}
colors = np.zeros_like(points)
for label_value, rgb in label_colors.items():
    colors[labels == label_value] = rgb

pcd.colors = o3d.utility.Vector3dVector(colors)

# Visualize
o3d.visualization.draw([pcd])


In [None]:
# Check which electrode coordinates have no labeled point within the radius, e.g. because they lie outside the captured 3D point cloud

def check_all_coordinates_labeled(ply_path, npz_path, coord_txt_path, radius=0.006):
    """
    Report which electrode coordinates have no labeled point nearby.

    For each coordinate in ``coord_txt_path``, this looks at the points of
    ``ply_path`` within ``radius``. It checks whether any of them carries a
    non-zero label in the ``labels`` array of ``npz_path``. Coordinates with no
    such point are printed together with their distance to the nearest point.

    Args:
        ply_path: Point cloud whose point order matches the saved labels
            (e.g. the cloud that was passed to ``label_point_cloud``).
        npz_path: .npz archive containing a ``labels`` array, one label per point.
        coord_txt_path: Electrode coordinate file (``<name> <x> <y> <z>`` lines).
        radius: Search radius; should match the radius used for labeling.

    Returns:
        (matched_coords, unmatched_coords): lists of 0-based indices into the
        parsed coordinate rows.

    Raises:
        ValueError: if the number of labels differs from the number of points,
            which means the .ply and .npz do not describe the same cloud.
    """
    # Load point cloud
    pcd = o3d.io.read_point_cloud(ply_path)
    points = np.asarray(pcd.points)

    # Load labels; the context manager closes the archive handle.
    with np.load(npz_path) as data:
        labels = data['labels']

    if len(labels) != len(points):
        raise ValueError(
            f"{npz_path!r} has {len(labels)} labels but {ply_path!r} has "
            f"{len(points)} points; they must come from the same point cloud"
        )

    # Load electrode coordinates with the shared parser; label codes are not needed here.
    coords, _ = parse_electrode_file(coord_txt_path)

    matched_coords = []
    unmatched_coords = []

    print("\n Checking distances for each coordinate:\n")
    for i, coord in enumerate(coords):
        dists = np.linalg.norm(points - coord, axis=1)
        within_radius = dists <= radius

        if np.any(labels[within_radius] != 0):
            matched_coords.append(i)
        else:
            unmatched_coords.append(i)
            # An empty cloud has no nearest point; report infinity instead of raising.
            min_dist = dists.min() if dists.size else float('inf')
            print(f" Coordinate #{i} - No labeled point. Min distance: {min_dist:.6f} | Coord: {coord}")

    print(f"\n {len(matched_coords)} / {len(coords)} coordinates labeled at least one point.")
    if unmatched_coords:
        print(f"\n {len(unmatched_coords)} unmatched coordinates total.")

    return matched_coords, unmatched_coords



In [None]:
matched_idx, unmatched_idx = check_all_coordinates_labeled(
    ply_path="path_file.ply",
    npz_path="path_file.npz",
    coord_txt_path="path_file.txt",
    radius=0.006,
)
# Show only the unmatched indices; echoing the full matched list floods the output.
unmatched_idx



📏 Checking distances for each coordinate:

❌ Coordinate #47 - No labeled point. Min distance: 0.032442 | Coord: [0.065217 0.048208 0.055764]
❌ Coordinate #48 - No labeled point. Min distance: 0.010665 | Coord: [0.065371 0.070002 0.031842]
❌ Coordinate #55 - No labeled point. Min distance: 0.008176 | Coord: [0.054304 0.072238 0.004747]
❌ Coordinate #62 - No labeled point. Min distance: 0.007531 | Coord: [ 0.040542  0.072951 -0.014787]
❌ Coordinate #118 - No labeled point. Min distance: 0.020703 | Coord: [-0.068504  0.060436  0.084265]

✅ 125 / 130 coordinates labeled at least one point.

⚠️ 5 unmatched coordinates total.


([0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  49,
  50,
  51,
  52,
  53,
  54,
  56,
  57,
  58,
  59,
  60,
  61,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  110,
  111,
  112,
  113,
  114,
  115,
  116,
  117,
  119,
  120,
  121,
  122,
  123,
  124,
  125,
  126,
  127,
  128,
  129],
 [47, 48, 55, 62, 118])