In [8]:
import os
import numpy as np
import pandas as pd
from scipy.ndimage import gaussian_filter
from skimage.measure import marching_cubes
import trimesh
from sklearn.neighbors import NearestNeighbors

# Load Gene Positions 
def load_gene_positions(csv_path):
    """Read the 3-D midpoint coordinates stored in a CSV file.

    Args:
        csv_path: Path of a CSV file containing the columns ``middle_x``,
            ``middle_y`` and ``middle_z``.

    Returns:
        A 2-D NumPy array with one ``(x, y, z)`` row per CSV row; rows in
        which any of the three coordinates is missing are dropped.
    """
    coordinate_columns = ['middle_x', 'middle_y', 'middle_z']
    table = pd.read_csv(csv_path)
    complete_rows = table[coordinate_columns].dropna()
    return complete_rows.values

# OPTION A: Estimate radius using nearest neighbor distance
def estimate_radius_nnd(points, scale=1.5):
    """Derive a metaball radius from the mean nearest-neighbour distance.

    Args:
        points: Array of point coordinates, one point per row.
        scale: Multiplier applied to the mean nearest-neighbour distance.

    Returns:
        ``scale`` times the mean of each point's distance to its
        second-closest fitted point.
    """
    # Two neighbours are requested because the points are queried against
    # themselves: column 0 is normally the zero-distance self match, so
    # column 1 holds the distance to the nearest *other* point.
    knn = NearestNeighbors(n_neighbors=2)
    knn.fit(points)
    neighbor_distances = knn.kneighbors(points)[0]
    nearest_other = neighbor_distances[:, 1]
    return scale * np.mean(nearest_other)

# OPTION B: Estimate radius based on point count (volume heuristic)
def estimate_radius_by_count(points, k=5.0):
    """Derive a metaball radius from the number of points alone.

    The radius shrinks with the cube root of the point count, i.e. it is
    ``k / n ** (1 / 3)`` for ``n = len(points)``.

    Args:
        points: Any sized collection of points; only its length is used.
        k: Numerator of the heuristic (the radius returned for one point).

    Returns:
        The heuristic radius as a float.
    """
    point_count = len(points)
    cube_root = point_count ** (1 / 3)
    return k / cube_root


# Create scalar field via metaball union
def create_metaball_field(points, grid_resolution=100, radius=5.0, grid_margin=10):
    """Sample a sum-of-Gaussians ("metaball") scalar field on a regular grid.

    Every point contributes ``exp(-d**2 / (2 * radius**2))`` to each grid
    node, where ``d`` is the Euclidean distance between the node and the
    point. The grid spans the bounding box of the points, padded by
    ``grid_margin`` on every side.

    Args:
        points: Array-like of shape ``(n_points, 3)`` (any further columns
            are ignored), holding x, y, z coordinates.
        grid_resolution: Number of grid nodes along each axis.
        radius: Width (standard deviation) of each Gaussian; must be > 0.
        grid_margin: Padding added around the bounding box, in the same
            units as ``points``.

    Returns:
        A tuple ``(field, (x, y, z))`` where ``field`` has shape
        ``(grid_resolution, grid_resolution, grid_resolution)`` indexed as
        ``field[i, j, k]`` for node ``(x[i], y[j], z[k])``, and ``x``, ``y``,
        ``z`` are the 1-D node coordinates along each axis.

    Raises:
        ValueError: If ``points`` is empty or not of shape ``(n, >=3)``, or
            if ``radius`` is not strictly positive.
    """
    points = np.asarray(points, dtype=float)
    if points.ndim != 2 or points.shape[0] == 0 or points.shape[1] < 3:
        raise ValueError("points must be a non-empty array of shape (n_points, 3)")
    # `not radius > 0` also rejects NaN; a zero radius would otherwise turn
    # the field into NaNs through a 0/0 division.
    if not radius > 0:
        raise ValueError("radius must be strictly positive")

    coords = points[:, :3]
    mins = coords.min(axis=0) - grid_margin
    maxs = coords.max(axis=0) + grid_margin

    x = np.linspace(mins[0], maxs[0], grid_resolution)
    y = np.linspace(mins[1], maxs[1], grid_resolution)
    z = np.linspace(mins[2], maxs[2], grid_resolution)
    field = np.zeros((len(x), len(y), len(z)))

    # The Gaussian is separable: exp(-(dx^2 + dy^2 + dz^2) / s) equals
    # exp(-dx^2 / s) * exp(-dy^2 / s) * exp(-dz^2 / s). Evaluating three 1-D
    # profiles per point and combining them by broadcasting avoids calling
    # exp on the whole 3-D grid for every point.
    two_sigma_sq = 2.0 * radius**2
    for px, py, pz in coords:
        gx = np.exp(-((x - px) ** 2) / two_sigma_sq)
        gy = np.exp(-((y - py) ** 2) / two_sigma_sq)
        gz = np.exp(-((z - pz) ** 2) / two_sigma_sq)
        field += gx[:, None, None] * gy[None, :, None] * gz[None, None, :]

    return field, (x, y, z)

# automatically detect marching cube level
def auto_select_level(field, grid_axes, level_range=(0.3, 0.7), steps=10,
                      target_area=10000):
    """Pick the iso-level whose extracted surface area is closest to a target.

    ``steps`` evenly spaced levels across ``level_range`` (both ends
    included) are tried. For each one a surface is extracted with
    ``marching_cubes`` and its area compared with ``target_area``.

    Args:
        field: 3-D scalar field to contour.
        grid_axes: Accepted for symmetry with ``extract_metaball_surface``
            but not used here. The area is measured on the vertex
            coordinates exactly as ``marching_cubes`` returns them with its
            default spacing, not rescaled by the grid axes.
        level_range: ``(lowest, highest)`` candidate iso-level.
        steps: Number of candidate levels.
        target_area: Surface area the chosen level should approximate, in
            the same units as the measured area (see ``grid_axes``).

    Returns:
        The candidate level with the smallest ``abs(area - target_area)``
        (the first such level on ties), or ``None`` if no candidate yielded
        a surface.
    """
    best_level = None
    best_score = -np.inf

    for lvl in np.linspace(level_range[0], level_range[1], steps):
        try:
            verts, faces, _, _ = marching_cubes(field, level=lvl)
        except (ValueError, RuntimeError):
            # marching_cubes cannot build a surface for this level (e.g. it
            # lies outside the field's value range); try the next candidate.
            # Anything else is a genuine error and is allowed to propagate.
            continue

        surface_area = trimesh.Trimesh(vertices=verts, faces=faces).area
        score = -abs(surface_area - target_area)
        if score > best_score:
            best_score = score
            best_level = lvl

    return best_level
    
# Extract surface from scalar field
def extract_metaball_surface(field, grid_axes, level=0.5):
    """Contour the scalar field and return the surface in grid-axis coordinates.

    Args:
        field: 3-D scalar field sampled on the grid described by
            ``grid_axes``.
        grid_axes: Three 1-D coordinate sequences ``(x, y, z)``; the step
            between the first two entries of each is taken as that axis'
            node spacing, and the first entry as its origin.
        level: Iso-value passed to ``marching_cubes``.

    Returns:
        A ``trimesh.Trimesh`` (built with ``process=False``) whose vertices
        are the marching-cubes vertices scaled by the per-axis spacing and
        shifted by the per-axis origin.
    """
    index_verts, triangles, _, _ = marching_cubes(field, level=level)

    axis_x, axis_y, axis_z = grid_axes
    step = np.array([
        axis_x[1] - axis_x[0],
        axis_y[1] - axis_y[0],
        axis_z[1] - axis_z[0],
    ])
    start = np.array([axis_x[0], axis_y[0], axis_z[0]])

    # marching_cubes reports vertices in array-index units; map them onto
    # the coordinates of the sampling grid.
    world_verts = index_verts * step + start
    return trimesh.Trimesh(vertices=world_verts, faces=triangles, process=False)


# Main Pipeline
def run_metaball_pipeline(csv_path, save_path=None):
    """Build a metaball surface mesh from the gene positions in one CSV file.

    Steps: load the coordinates, estimate a metaball radius from the
    nearest-neighbour distance, sample the metaball field on a 150-node grid,
    choose an iso-level automatically, extract the surface, and optionally
    export it.

    Args:
        csv_path: CSV file read by ``load_gene_positions``.
        save_path: Optional output file. The export format is taken from its
            extension, and its parent directory is created if needed. When
            falsy, nothing is written.

    Returns:
        The extracted mesh, so callers can use it even when ``save_path`` is
        not given.
    """
    points = load_gene_positions(csv_path)

    # ----- radius methods -----
    # METHOD A: nearest neighbor
    radius = estimate_radius_nnd(points, scale=1.5)
    # METHOD B: volume-based
    # radius = estimate_radius_by_count(points, k=5.0)

    print(f"Using radius: {radius:.3f}")

    field, axes = create_metaball_field(points, grid_resolution=150, radius=radius)
    level = auto_select_level(field, axes)
    print(f"Level selected: {level}")
    mesh = extract_metaball_surface(field, axes, level=level)

    if save_path:
        # A bare file name has an empty dirname, and os.makedirs('') raises,
        # so only create the folder when there is one to create.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        file_type = os.path.splitext(save_path)[-1].replace('.', '')
        mesh.export(save_path, file_type=file_type)
        print(f"Saved mesh to {save_path}")

    return mesh

def main(base_dir="data/green_monkey/all_structure_files/",
         conditions=('vacv', 'untr')):
    """Run the metaball pipeline for every structure CSV under ``base_dir``.

    The expected layout is
    ``<base_dir>/<chrom>/<hrs>/<cond>/structure_<hrs>_<cond>_gene_info.csv``.
    Each file found produces
    ``<base_dir>/<chrom>/spatial_data/overall_shapes/<chrom>_<hrs>_<cond>_metaball.obj``.
    Entries that are not directories, and combinations with no CSV file, are
    skipped silently.

    Args:
        base_dir: Root folder holding one sub-folder per chromosome.
        conditions: Condition sub-folder names looked up under every
            time-point folder.
    """
    # Listings are sorted so the processing order (and therefore the log)
    # is the same on every run and platform.
    for chrom in sorted(os.listdir(base_dir)):
        chrom_path = os.path.join(base_dir, chrom)
        if not os.path.isdir(chrom_path):
            continue
        print(f"Processing {chrom}...")
        out_dir = os.path.join(chrom_path, 'spatial_data', 'overall_shapes')

        for hrs in sorted(os.listdir(chrom_path)):
            hrs_path = os.path.join(chrom_path, hrs)
            if not os.path.isdir(hrs_path):
                continue

            for cond in conditions:
                fp = os.path.join(hrs_path, cond, f'structure_{hrs}_{cond}_gene_info.csv')
                # isfile is False when the condition folder itself is
                # missing, so a single check covers both cases.
                if not os.path.isfile(fp):
                    continue
                print(fp)

                os.makedirs(out_dir, exist_ok=True)  # create folder if missing
                obj_path = os.path.join(out_dir, f"{chrom}_{hrs}_{cond}_metaball.obj")

                run_metaball_pipeline(fp, save_path=obj_path)

# Entry point: process every chromosome folder when this code is executed
# directly (as a script, or as a notebook cell where __name__ is "__main__")
# rather than imported.
if __name__ == "__main__":
    main()
    
# # ---- Run ----
# csv_path = "data/green_monkey/all_structure_files/chr1/24hrs/vacv/structure_24hrs_vacv_gene_info.csv"
# obj_path = "data/green_monkey/all_structure_files/chr1/spatial_data/chr1_24hrs_vacv_metaball.obj"

# run_metaball_pipeline(csv_path, save_path=obj_path)  # NOTE: the pipeline takes no `view` argument


Processing chr23...
data/green_monkey/all_structure_files/chr23/24hrs/vacv/structure_24hrs_vacv_gene_info.csv
Using radius: 0.302
Level selected: 0.3
Saved mesh to data/green_monkey/all_structure_files/chr23/spatial_data/overall_shapes/chr23_24hrs_vacv_metaball.obj
data/green_monkey/all_structure_files/chr23/24hrs/untr/structure_24hrs_untr_gene_info.csv
Using radius: 0.305
Level selected: 0.3
Saved mesh to data/green_monkey/all_structure_files/chr23/spatial_data/overall_shapes/chr23_24hrs_untr_metaball.obj
data/green_monkey/all_structure_files/chr23/12hrs/vacv/structure_12hrs_vacv_gene_info.csv
Using radius: 0.316
Level selected: 0.3
Saved mesh to data/green_monkey/all_structure_files/chr23/spatial_data/overall_shapes/chr23_12hrs_vacv_metaball.obj
data/green_monkey/all_structure_files/chr23/12hrs/untr/structure_12hrs_untr_gene_info.csv
Using radius: 0.310
Level selected: 0.3
Saved mesh to data/green_monkey/all_structure_files/chr23/spatial_data/overall_shapes/chr23_12hrs_untr_metaball.