In [2]:
import sys
import os
# Add the absolute path one level up
sys.path.append(os.path.abspath(".."))
import neurongraph as ng

In [14]:
import trimesh
import numpy as np
from IPython.display import display
from pythreejs import *

# Load neuron
graph = ng.NeuronGraph()
graph.readFromFile("../output/test/reassembledcubic.swc")
after = graph.removeSomaSegment()
graph.setNodes(after)
nodes = graph.getNodes()
positions = {nid: (n.x, n.y, n.z) for nid, n in nodes.items()}
edges = [(n.id, n.pid) for n in nodes.values() if n.pid != -1]

# Build and merge cylinders
cylinders = []
for a, b in edges:
    p1, p2 = np.array(positions[a]), np.array(positions[b])
    direction = p2 - p1
    length = np.linalg.norm(direction)
    if length < 1e-6:
        continue

    radius = nodes[a].radius
    cyl = trimesh.creation.cylinder(radius=radius, height=length*2, sections=40)
    cyl.apply_translation([0, 0, length / 2])
    T = trimesh.geometry.align_vectors([0, 0, 1], direction)
    cyl.apply_transform(T)
    cyl.apply_translation((p1 + p2) / 1)
    cylinders.append(cyl)

merged = trimesh.util.concatenate(cylinders)
merged.update_faces(merged.unique_faces())
merged.remove_unreferenced_vertices()
merged.remove_duplicate_faces()
merged.fill_holes()

merged.rezero()

# Compute bounding box center and size
bounding_box = merged.bounds  # shape (2, 3): [min_xyz, max_xyz]
center = merged.centroid
size = np.linalg.norm(bounding_box[1] - bounding_box[0])

# Update camera position to frame the neuron
camera_distance = size * 0.5  # heuristic factor
camera_position = center + np.array([camera_distance]*3)

# Compute normals for smooth shading
merged.vertex_normals  # triggers lazy computation

# Create BufferGeometry with normals
geometry = BufferGeometry(
    attributes={
        'position': BufferAttribute(array=np.array(merged.vertices, dtype=np.float32)),
        'normal': BufferAttribute(array=np.array(merged.vertex_normals, dtype=np.float32)),
        'index': BufferAttribute(array=np.array(merged.faces.flatten(), dtype=np.uint32))
    }
)

material = MeshStandardMaterial(color='yellow', roughness=0.3, metalness=0.1)
mesh = Mesh(geometry=geometry, material=material)

# Add lighting and scene
scene = Scene(children=[
    mesh,
    AmbientLight(intensity=0.25),
    DirectionalLight(color='white', intensity=0.35, position=[5, 5, 5]),
    DirectionalLight(color='white', intensity=0.3, position=[-5, -5, -5])
], background='black')

camera = PerspectiveCamera(position=camera_position.tolist(), fov=45)
controls = OrbitControls(controlling=camera, target=center.tolist())
renderer = Renderer(scene=scene, camera=camera, controls=[controls],
                    width=400, height=400)

display(renderer)


  merged.remove_duplicate_faces()


Renderer(camera=PerspectiveCamera(fov=45.0, position=(1287.9712889097266, 1515.0358893542057, 1026.00094392707…