In [1]:
import pandas as pd
import plotly.graph_objects as go

In [2]:
def _read_with_string_ids(path):
    """Read a CSV file, forcing its ``root_id`` column to the pandas ``string`` dtype.

    Reading the IDs as strings keeps the merge keys identical across files
    (presumably they are long integer IDs that must not be parsed numerically).
    """
    return pd.read_csv(path, dtype={"root_id": "string"})


right_root_ids = _read_with_string_ids("adult_data/root_id_to_index.csv")

# Inner merge (pandas' default): keep only classified neurons whose root_id is
# listed in root_id_to_index.csv, then replace remaining missing values.
# NOTE(review): fillna applies to every column; any numeric column containing
# NaN would become object dtype holding "Unknown" — confirm that is intended.
all_neurons = _read_with_string_ids("adult_data/classification_clean.csv").merge(
    right_root_ids, on="root_id"
).fillna("Unknown")

all_coords = _read_with_string_ids("adult_data/all_coords_clean.csv")

In [3]:
# Attach the coordinate table to each classified neuron (inner join on root_id).
# NOTE(review): `all` shadows Python's built-in all(); later cells read this
# name, so any rename has to be done across the whole notebook at once.
all = pd.merge(all_neurons, all_coords, on="root_id")

In [4]:
# Collapse every cell_type with fewer than `n` rows into a single "others"
# label, so the per-cell-type plot below gets a manageable number of traces.
n = 100

counts = all["cell_type"].value_counts()

small_categories = counts[counts < n].index
# Vectorized: flag rare labels with `isin` and overwrite them via `mask`,
# instead of invoking a Python lambda once per row with `apply`.
all["cell_type"] = all["cell_type"].mask(
    all["cell_type"].isin(small_categories), "others"
)

In [10]:
# Interactive 3D scatter of neuron coordinates: one trace per cell_type, plus a
# dropdown that shows every type or isolates a single one.
df = all.copy()

fig = go.Figure()

# `cell_types` is in order of first appearance. Traces are added in exactly this
# order, so the dropdown's visibility masks below line up with trace indices.
cell_types = df["cell_type"].unique()
# Partition the frame once instead of re-scanning it with a boolean mask for
# every cell type (O(rows) rather than O(rows * types)).
groups = df.groupby("cell_type", sort=False)
for cell_type in cell_types:
    subset = groups.get_group(cell_type)
    fig.add_trace(
        go.Scatter3d(
            x=subset["x"],
            y=subset["y"],
            z=subset["z"],
            mode="markers",
            marker=dict(size=2, opacity=0.6),
            name=cell_type,
            visible=True,
        )
    )


def _visibility_button(label, visible, title):
    """Build a dropdown button that sets per-trace visibility and the figure title.

    Args:
        label: Text shown on the button.
        visible: One bool per trace, in trace order.
        title: Figure title displayed after the button is clicked.

    Returns:
        A dict suitable for an ``updatemenus`` ``buttons`` entry.
    """
    return dict(
        label=label,
        method="update",
        args=[{"visible": visible}, {"title": title}],
    )


# First button restores every trace; then one button per cell type.
buttons = [
    _visibility_button("Show All", [True] * len(cell_types), "Showing: All Types")
]
buttons.extend(
    _visibility_button(ct, [t == ct for t in cell_types], f"Showing: {ct}")
    for ct in cell_types
)

# Shared look for all three scene axes: no background panes, grid, or zero line.
axis_style = dict(showbackground=False, showgrid=False, zeroline=False)

fig.update_layout(
    scene=dict(
        # Each axis title is set once here; the former duplicate
        # `xaxis_title`/`yaxis_title`/`zaxis_title` keys were redundant.
        xaxis=dict(title="X", **axis_style),
        yaxis=dict(title="Y", **axis_style),
        zaxis=dict(title="Z", **axis_style),
    ),
    updatemenus=[
        {
            "buttons": buttons,
            "direction": "down",
            "pad": {"r": 10, "t": 10},
            "showactive": True,
            "x": 0.1,
            "xanchor": "left",
            "y": 1.1,
            "yanchor": "top",
        }
    ],
    title="3D Scatter Plot by Cell Type",
)

fig.show()