In [4]:
import os
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter
from skimage.measure import marching_cubes
import trimesh
from sklearn.neighbors import NearestNeighbors

# Load Gene Positions
def load_gene_positions(csv_path):
    """Read per-gene 3D coordinates from a CSV file.

    Parameters
    ----------
    csv_path : str or file-like
        Anything accepted by ``pandas.read_csv``. The table must contain the
        columns ``middle_x``, ``middle_y`` and ``middle_z``.

    Returns
    -------
    numpy.ndarray
        An ``(n, 3)`` array of coordinates. Rows missing a value in any of the
        three coordinate columns are dropped.
    """
    coord_cols = ['middle_x', 'middle_y', 'middle_z']
    table = pd.read_csv(csv_path)
    coords = table[coord_cols].dropna()
    return coords.values

# OPTION A: Estimate radius using nearest neighbor distance 
def estimate_radius_nnd(points, scale=1.5):
    """Estimate a metaball radius from nearest-neighbour spacing.

    The radius is ``scale`` times the mean distance from each point to its
    nearest neighbour.

    Parameters
    ----------
    points : array_like, shape (n, 3)
        Point coordinates.
    scale : float
        Multiplier applied to the mean nearest-neighbour distance.

    Returns
    -------
    float
        The estimated radius, in the same units as ``points``.
    """
    # Querying the fitted points themselves: the first hit for each query is
    # the point itself (or an exact duplicate) at distance 0, so column 1
    # holds the distance to the nearest neighbour.
    knn = NearestNeighbors(n_neighbors=2)
    knn.fit(points)
    dists = knn.kneighbors(points)[0]
    return scale * dists[:, 1].mean()

# OPTION B: Estimate radius based on point count (volume heuristic)
def estimate_radius_by_count(points, k=5.0):
    """Estimate a metaball radius that shrinks with the number of points.

    Computes ``k / n**(1/3)``, where ``n = len(points)``. This is a volume
    heuristic: for a fixed volume, typical spacing scales like ``n**(-1/3)``.

    Parameters
    ----------
    points : sized collection
        Only its length is used. Empty input is not supported, because it
        divides by zero.
    k : float
        Scale constant.

    Returns
    -------
    float
        The estimated radius.
    """
    cube_root_n = len(points) ** (1 / 3)
    return k / cube_root_n


# Create scalar field via metaball union 
def create_metaball_field(points, grid_resolution=100, radius=5.0, grid_margin=10):
    """Sample a sum-of-Gaussians ("metaball") scalar field on a regular grid.

    Each point contributes ``exp(-d**2 / (2 * radius**2))`` to every grid
    sample, where ``d`` is the Euclidean distance to that sample.

    Parameters
    ----------
    points : array_like, shape (n, 3)
        Point coordinates.
    grid_resolution : int
        Number of samples along each axis.
    radius : float
        Gaussian width (standard deviation), in the units of ``points``.
    grid_margin : float
        Padding added on every side of the points' bounding box, in the units
        of ``points``.

    Returns
    -------
    field : numpy.ndarray, shape (grid_resolution, grid_resolution, grid_resolution)
        Field values, with ``field[i, j, k]`` sampled at ``(x[i], y[j], z[k])``.
    (x, y, z) : tuple of numpy.ndarray
        Sample coordinates along each axis, produced by ``np.linspace``.
    """
    points = np.asarray(points, dtype=float)
    lo = points.min(axis=0) - grid_margin
    hi = points.max(axis=0) + grid_margin
    x, y, z = (np.linspace(lo[d], hi[d], grid_resolution) for d in range(3))

    # The isotropic Gaussian is separable:
    #   exp(-(dx^2 + dy^2 + dz^2) / 2r^2)
    #     = exp(-dx^2 / 2r^2) * exp(-dy^2 / 2r^2) * exp(-dz^2 / 2r^2)
    # so exp() only needs to run on the 1D axes (n x grid_resolution values)
    # instead of over the full 3D grid once per point.
    inv_two_r2 = 1.0 / (2.0 * radius ** 2)
    gx = np.exp(-((x[None, :] - points[:, [0]]) ** 2) * inv_two_r2)
    gy = np.exp(-((y[None, :] - points[:, [1]]) ** 2) * inv_two_r2)
    gz = np.exp(-((z[None, :] - points[:, [2]]) ** 2) * inv_two_r2)

    field = np.zeros((grid_resolution,) * 3)
    for fx, fy, fz in zip(gx, gy, gz):
        # Outer product fx (x) fy (x) fz, accumulated with one grid-sized
        # multiply-add per point. The axis order matches meshgrid(indexing='ij').
        field += fx[:, None, None] * np.outer(fy, fz)[None, :, :]

    return field, (x, y, z)

# automatically detect marching cube level 
def auto_select_level(field, grid_axes, level_range=(0.3, 0.7), steps=10,
                      target_area=10000.0):
    """Pick the iso-level whose surface area is closest to ``target_area``.

    Runs marching cubes at ``steps`` evenly spaced levels in ``level_range``.
    It returns the level whose mesh area is nearest the target; on ties, the
    first such level wins.

    Area is measured in voxel-index units, because ``marching_cubes`` is called
    without ``spacing``. The right ``target_area`` therefore depends on the grid
    resolution, not on the coordinate units of the original points.

    Parameters
    ----------
    field : numpy.ndarray
        3D scalar field.
    grid_axes : tuple
        Unused. Kept so existing call sites keep working.
    level_range : tuple of float
        ``(low, high)`` bounds of the levels to try.
    steps : int
        Number of candidate levels.
    target_area : float
        Desired surface area in voxel units. The default is the value that
        was previously hard-coded.

    Returns
    -------
    float or None
        The best level, or ``None`` if no candidate level produced a surface.
    """
    best_level = None
    best_error = np.inf

    for lvl in np.linspace(level_range[0], level_range[1], steps):
        try:
            verts, faces, _, _ = marching_cubes(field, level=lvl)
        except (ValueError, RuntimeError):
            # Level lies outside the field's value range, or no surface was
            # found at it: skip this candidate rather than abort the search.
            continue
        # process=False: only the area is needed, so skip trimesh's vertex merging.
        area = trimesh.Trimesh(vertices=verts, faces=faces, process=False).area
        error = abs(area - target_area)
        if error < best_error:
            best_error = error
            best_level = lvl

    return best_level
    
# Extract surface from scalar field
def extract_metaball_surface(field, grid_axes, level=0.5):
    """Extract an iso-surface of ``field`` as a mesh in world coordinates.

    Parameters
    ----------
    field : numpy.ndarray
        3D scalar field, indexed like the grid described by ``grid_axes``.
    grid_axes : tuple of numpy.ndarray
        Per-axis sample coordinates, assumed uniformly spaced (as produced by
        ``np.linspace``). Each axis needs at least two samples.
    level : float or None
        Iso-value passed to ``marching_cubes``. With ``None``, scikit-image
        picks its own default level.

    Returns
    -------
    trimesh.Trimesh
        The surface mesh, built without trimesh post-processing
        (``process=False``).
    """
    spacing = np.array([axis[1] - axis[0] for axis in grid_axes])
    origin = np.array([axis[0] for axis in grid_axes])

    verts, faces, _, _ = marching_cubes(field, level=level)
    # marching_cubes returns vertices in (fractional) voxel indices; map them
    # onto the coordinates the field was actually sampled at.
    verts_world = verts * spacing + origin
    return trimesh.Trimesh(vertices=verts_world, faces=faces, process=False)

# Visualize with spheres and mesh
def view_metaball_with_points(mesh, points, sphere_radius=0.2):
    """Render the metaball mesh with a small white sphere at each point.

    The mesh is copied before it is styled, so the caller's ``mesh`` keeps its
    original colours and normals (for example, for a later export).

    Parameters
    ----------
    mesh : trimesh.Trimesh
        Surface to display as translucent red.
    points : iterable of 3-element coordinates
        Positions of the point markers.
    sphere_radius : float
        Radius of each marker sphere. The default is the value that was
        previously hard-coded.

    Returns
    -------
    object
        Whatever ``trimesh.Scene.show`` returns for this call.
    """
    scene = trimesh.Scene()

    # Style a copy so the caller's mesh is not mutated.
    surface = mesh.copy()
    surface.visual.face_colors = [200, 30, 30, 25]  # red, alpha 25/255 (translucent)
    surface.fix_normals()
    scene.add_geometry(surface)

    # Build the marker geometry once; each point gets a translated copy.
    marker = trimesh.creation.icosphere(radius=sphere_radius, subdivisions=2)
    marker.visual.vertex_colors = [255, 255, 255, 255]  # white
    for pt in points:
        sphere = marker.copy()
        sphere.apply_translation(pt)
        scene.add_geometry(sphere)

    # NOTE(review): trimesh.Scene does not appear to define `ambient_light`;
    # this assignment may only set an unused attribute. Confirm it affects
    # rendering before relying on it.
    scene.ambient_light = [0.4, 0.4, 0.4, 1.0]

    return scene.show(jupyter=True)

# Main Pipeline 
def run_metaball_pipeline(csv_path, view=True, save_path=None,
                          grid_resolution=150, radius_method="nnd"):
    """Build a metaball surface around gene positions, then save and/or view it.

    Parameters
    ----------
    csv_path : str
        CSV containing ``middle_x``, ``middle_y`` and ``middle_z`` columns.
    view : bool
        If true, render the mesh and points and return the viewer.
    save_path : str or None
        If given, export the mesh here. The format is taken from the file
        extension.
    grid_resolution : int
        Samples per axis of the scalar field (previously hard-coded to 150).
    radius_method : {"nnd", "count"}
        ``"nnd"`` uses ``estimate_radius_nnd(scale=1.5)``, the previous
        default. ``"count"`` uses ``estimate_radius_by_count(k=5.0)``.

    Returns
    -------
    object or None
        The result of ``view_metaball_with_points`` when ``view`` is true,
        otherwise ``None``.

    Raises
    ------
    ValueError
        If ``radius_method`` is unknown, or ``save_path`` has no extension.
    """
    points = load_gene_positions(csv_path)

    if radius_method == "nnd":
        radius = estimate_radius_nnd(points, scale=1.5)
    elif radius_method == "count":
        radius = estimate_radius_by_count(points, k=5.0)
    else:
        raise ValueError(
            f"Unknown radius_method {radius_method!r}; expected 'nnd' or 'count'"
        )
    print(f"Using radius: {radius:.3f}")

    field, axes = create_metaball_field(points, grid_resolution=grid_resolution, radius=radius)
    level = auto_select_level(field, axes)
    if level is None:
        # No candidate level produced a surface. Fall back explicitly to the
        # midpoint of the field's value range instead of passing None on.
        level = float((field.min() + field.max()) / 2)
        print(f"No level in the search range produced a surface; falling back to {level:.3f}")
    print(f"Level selected: {level}")

    mesh = extract_metaball_surface(field, axes, level=level)
    print(mesh.vertices.shape, mesh.faces.shape)

    if save_path:
        file_type = os.path.splitext(save_path)[1].lstrip('.').lower()
        if not file_type:
            raise ValueError(f"Cannot infer mesh format from save_path {save_path!r}")
        mesh.export(save_path, file_type=file_type)
        print(f"Saved mesh to {save_path}")

    if view:
        return view_metaball_with_points(mesh, points)

#  Run 
# Per-gene structure table for chr1, 12 hrs, VACV.
csv_path = (
    "data/green_monkey/all_structure_files/chr1/12hrs/vacv/"
    "structure_12hrs_vacv_gene_info.csv"
)
# Optional export target: uncomment, then pass save_path=obj_path below.
# NOTE(review): this path says "24hrs" while csv_path is the 12hrs VACV
# structure. Confirm the intended timepoint before enabling the export.
# obj_path = "data/green_monkey/all_structure_files/chr1/spatial_data/chr1_24hrs_vacv_metaball.obj"

# run_metaball_pipeline(csv_path, view=True, save_path=obj_path)
# Kept as the cell's last bare expression so the notebook displays the viewer.
run_metaball_pipeline(csv_path=csv_path, view=True)


Using radius: 0.259
Level selected: 0.3
(15262, 3) (30536, 3)
