# In [None]:

import pandas as pd

def load_swc_to_graph(swc_file):
    """Load an SWC skeleton into an undirected, distance-weighted graph.

    Parameters
    ----------
    swc_file : str or path-like
        Whitespace-delimited SWC file; lines starting with ``#`` are ignored.
        Columns are read as ``id, type, x, y, z, radius, parent``; a parent
        of ``-1`` marks a root sample.

    Returns
    -------
    G : networkx.Graph
        One node per SWC sample (integer ``id``, added in file order) and one
        edge per parent/child link, weighted by the Euclidean distance between
        the two samples.
    node_positions : dict[int, numpy.ndarray]
        Maps each node id to its ``[x, y, z]`` coordinate, in file order.

    Raises
    ------
    KeyError
        If a sample references a parent id that does not appear in the file.
    """
    # Imported locally: the module only imports pandas at top level, so
    # ``np`` / ``nx`` would otherwise be undefined when this function runs.
    import numpy as np
    import networkx as nx

    df = pd.read_csv(
        swc_file,
        comment='#',
        sep=r"\s+",  # replaces delim_whitespace=True (deprecated in pandas 2.2)
        header=None,
        names=["id", "type", "x", "y", "z", "radius", "parent"],
    )

    ids = df["id"].to_numpy(dtype=int)
    parents = df["parent"].to_numpy(dtype=int)
    coords = df[["x", "y", "z"]].to_numpy(dtype=float)

    # First pass: register every sample, so a child may precede its parent
    # in the file without triggering a lookup failure.
    node_positions = {int(node_id): xyz for node_id, xyz in zip(ids, coords)}
    G = nx.Graph()
    G.add_nodes_from(node_positions)

    # Second pass: connect each non-root sample to its parent.
    for node_id, parent_id in zip(ids.tolist(), parents.tolist()):
        if parent_id == -1:
            continue
        dist = np.linalg.norm(node_positions[node_id] - node_positions[parent_id])
        G.add_edge(parent_id, node_id, weight=float(dist))

    return G, node_positions


def compute_synaptic_embeddings(
    client,
    segment_id,
    nucleus_id,
    skel_path,
    k=20,
    alpha=1.0,
    synapse_table_name='synapse_table',
    output_csv_path=None,
    desired_resolution=None,
):
    """
    Compute spectral (graph-Laplacian) embeddings of a neuron's input synapses.

    The skeleton at ``{skel_path}/{segment_id}.swc`` is loaded as a
    distance-weighted graph. The ``k`` smallest eigenpairs of its weighted
    Laplacian are computed. Every synapse whose ``post_pt_root_id`` equals
    ``segment_id`` is snapped to its nearest skeleton node and embedded as
    that node's eigenvector entries, each scaled by ``eigenvalue ** -alpha``.

    Parameters:
    - client: Initialized CAVEclient object
    - segment_id: Root ID of the neuron; also the stem of the SWC filename
    - nucleus_id: Nucleus ID (currently unused; the SWC file is located by segment_id)
    - skel_path: Directory containing the skeleton SWC file
    - k: Number of Laplacian modes (default 20); must be smaller than the
      number of skeleton nodes
    - alpha: Exponent applied to the eigenvalues; each mode is scaled by
      eigenvalue ** -alpha. The default 1.0 gives the 1/eigenvalue scaling.
    - synapse_table_name: Name of the synapse table
    - output_csv_path: If given, save the embeddings to this CSV file
    - desired_resolution: Optional resolution forwarded to ``query_table`` so
      synapse positions come back in the skeleton's units.
      NOTE(review): skeleton units are not established here; confirm them
      before relying on nearest-node matching.

    Returns:
    - DataFrame with columns ``mode_1`` .. ``mode_k`` (ascending eigenvalue)
      plus ``synapse_id``. The DataFrame is empty if no synapses are found.

    Raises:
    - ValueError: if ``k`` is not in ``[1, n_nodes - 1]``.

    Notes:
    - Zero-eigenvalue modes are scaled by 0 instead of being divided by a
      numerical ~0. These are the modes that are constant on each connected
      component.
    - Eigenvector signs are arbitrary, so column signs may flip between runs.
    """
    import os

    import numpy as np
    import pandas as pd
    import networkx as nx
    from scipy.sparse.linalg import eigsh
    from sklearn.neighbors import KDTree

    # Step 1: Load skeleton and build graph (edges/weights come from the SWC).
    skel_file = os.path.join(skel_path, f"{segment_id}.swc")
    G, node_positions = load_swc_to_graph(skel_file)

    # One node ordering shared by the Laplacian rows and the KD-tree points,
    # so a KD-tree index is directly an eigenvector row index.
    node_ids = list(node_positions)
    n_nodes = len(node_ids)
    if not 0 < k < n_nodes:
        raise ValueError(
            f"k must satisfy 0 < k < number of skeleton nodes ({n_nodes}); got k={k}"
        )

    # Step 2: Query only relevant synapses (before the costly eigensolve).
    query_kwargs = {}
    if desired_resolution is not None:
        query_kwargs['desired_resolution'] = desired_resolution
    synapse_df = client.materialize.query_table(
        synapse_table_name,
        split_positions=True,
        filter_in_dict={'post_pt_root_id': [segment_id]},
        **query_kwargs,
    )

    if synapse_df.empty:
        print(f"No synapses found for segment ID {segment_id}.")
        return pd.DataFrame()

    # Step 3: Laplacian eigenpairs, sorted so mode_1 has the smallest eigenvalue.
    L = nx.laplacian_matrix(G, nodelist=node_ids, weight='weight').astype(float)
    eigenvalues, eigenvectors = eigsh(L, k=k, which='SM')
    order = np.argsort(eigenvalues)
    eigenvalues = eigenvalues[order]
    eigenvectors = eigenvectors[:, order]

    # The zero eigenvalue's multiplicity equals the number of connected
    # components of the positive-weight graph. After sorting, those modes
    # come first. Give them scale 0 rather than dividing by ~0.
    positive = nx.subgraph_view(G, filter_edge=lambda u, v: G[u][v]['weight'] > 0)
    n_null = min(nx.number_connected_components(positive), k)
    scale = np.zeros_like(eigenvalues)
    scale[n_null:] = 1.0 / eigenvalues[n_null:] ** alpha

    # Step 4: Map synapses to nearest nodes. With split_positions=True,
    # CAVEclient returns one column per axis (e.g. post_pt_position_x).
    # The post-synaptic point is used because the query filters on post_pt_root_id.
    post_cols = ['post_pt_position_x', 'post_pt_position_y', 'post_pt_position_z']
    if set(post_cols).issubset(synapse_df.columns):
        position_cols = post_cols
    else:
        position_cols = ['x', 'y', 'z']
    synapse_coords = synapse_df[position_cols].to_numpy(dtype=float)

    skeleton_coords = np.array([node_positions[n] for n in node_ids])
    tree = KDTree(skeleton_coords)
    _, indices = tree.query(synapse_coords, k=1)

    # Step 5: Embed synapses (vectorized row lookup + per-mode scaling).
    embeddings = eigenvectors[indices[:, 0]] * scale

    embedding_df = pd.DataFrame(embeddings, columns=[f'mode_{i + 1}' for i in range(k)])
    embedding_df['synapse_id'] = synapse_df['id'].values

    if output_csv_path:
        embedding_df.to_csv(output_csv_path, index=False)

    return embedding_df

import os

from caveclient import CAVEclient

# Neuron of interest in the minnie65 public datastack (single client instance).
client = CAVEclient('minnie65_public')
segment_id = 864691135122603047
nucleus_id = 292685

# Directory where the skeleton SWC is written and later read back from.
save_dir = "/data"
skel_path = save_dir
os.makedirs(save_dir, exist_ok=True)

# Fetch the skeleton from the CAVE skeleton service as an SWC-layout table.
# This replaces `client.materialize.get_skeleton` and
# `meshparty.skeleton_io.write_swc`; that path failed with
# ModuleNotFoundError because meshparty is not installed.
# NOTE(review): assumes output_format="swc" returns a pandas DataFrame with
# columns id/type/x/y/z/radius/parent; confirm against the installed caveclient.
skeleton = client.skeleton.get_skeleton(segment_id, output_format="swc")

# Save as a whitespace-delimited SWC file, in the column order that
# load_swc_to_graph reads.
swc_path = os.path.join(save_dir, f"{segment_id}.swc")
skeleton[["id", "type", "x", "y", "z", "radius", "parent"]].to_csv(
    swc_path, sep=" ", header=False, index=False
)

print(f"Saved SWC to: {swc_path}")

# NOTE(review): skeleton coordinates and synapse positions may use different
# units. Synapse positions default to the table's voxel resolution. Confirm
# the two match before trusting nearest-node assignment.
df = compute_synaptic_embeddings(
    client,
    segment_id=segment_id,
    nucleus_id=nucleus_id,
    skel_path=skel_path,
    k=20,
    output_csv_path="synaptic_embedding.csv"
)


# Previous run output (raised by `from meshparty import skeleton_io`):
# ModuleNotFoundError: No module named 'meshparty'