In [6]:
import numpy as np
import plotly.graph_objects as go

# --- CONFIG ---
# Input archive produced by the simulation and the HTML file to write.
NPZ_FILE: str = "data/galaxy_sim_hecate.npz"
OUTPUT_HTML: str = "galaxy_sim.html"
# Target cap on the number of animation frames; lower it for large datasets.
MAX_FRAMES: int = 200
# A progress message is printed for every LOG_EVERY-th frame that is built.
LOG_EVERY: int = 10

# --- Load simulation data ---
# The archive is opened in a `with` block so its file handle is released as
# soon as the arrays have been read. NpzFile members are only read from disk
# on item access, so everything needed later must be pulled out in here.
print("Loading simulation data...")
with np.load(NPZ_FILE) as data:
    positions = data["body_positions"]        # shape (T, N, 3)
    body_gal = data["body_galaxy"]            # shape (N,)
    gal_pos = data["galaxy_positions"]        # final positions of galaxies
    gal_mass = data["galaxy_masses"]
    times = data["times"]

# Fail early with a readable message instead of a tuple-unpacking error (wrong
# rank) or a later IndexError (fewer than three coordinates per body).
if positions.ndim != 3 or positions.shape[2] < 3:
    raise ValueError(
        "body_positions must have shape (T, N, 3), "
        f"got {positions.shape}"
    )
T, N, _ = positions.shape
print(f"Loaded positions: {positions.shape}, {T} timesteps, {N} bodies")

# --- Colors for galaxies ---
# `gal_index[k]` is the position of body k's galaxy id inside `unique_gals`.
# Indexing the palette through it (rather than through the raw ids) keeps the
# lookup valid when ids are not exactly 0..K-1 or are stored as floats.
unique_gals, gal_index = np.unique(body_gal, return_inverse=True)

# One hue per distinct galaxy, evenly spaced around the HSL colour wheel.
colors = [f"hsl({i*360/len(unique_gals)},80%,50%)" for i in range(len(unique_gals))]

# Per-body colour string, aligned with the body axis of the position data.
body_colors = np.asarray(colors)[gal_index]

# --- Build frames ---
print("Building frames...")
frames = []
# Ceiling division: the stride is the smallest one that yields at most
# MAX_FRAMES sampled timesteps. Floor division would let through almost
# twice that many (e.g. T=399 with MAX_FRAMES=200 gives a stride of 1).
step = max(1, -(-T // MAX_FRAMES))
frame_steps = range(0, T, step)
n_frames = len(frame_steps)  # computed once; only used for the progress log
for t_idx, t in enumerate(frame_steps):
    # Each frame carries the full marker trace for timestep `t`.
    frame = go.Frame(
        data=[go.Scatter3d(
            x=positions[t, :, 0],
            y=positions[t, :, 1],
            z=positions[t, :, 2],
            mode='markers',
            marker=dict(size=3, color=body_colors)
        )],
        name=f"frame_{t_idx}"
    )
    frames.append(frame)
    if t_idx % LOG_EVERY == 0:
        print(f"Frame {t_idx}/{n_frames} created")

# --- Initial figure ---
print("Creating figure...")

# Trace displayed before playback starts: every body at the first timestep.
first_positions = positions[0]
initial_trace = go.Scatter3d(
    x=first_positions[:, 0],
    y=first_positions[:, 1],
    z=first_positions[:, 2],
    mode='markers',
    marker=dict(size=3, color=body_colors),
)

# Options handed to Plotly's "animate" method when Play is pressed.
# NOTE(review): per the Plotly animation docs, `duration` is in milliseconds
# and 3D traces need redraw=True -- confirm against the installed version.
play_options = dict(
    frame=dict(duration=50, redraw=True),
    fromcurrent=True,
    mode="immediate",
)
# A leading None in `args` asks Plotly to play through all frames.
play_button = dict(label="Play", method="animate", args=[None, play_options])

axis_titles = dict(
    xaxis_title="X [AU]",
    yaxis_title="Y [AU]",
    zaxis_title="Z [AU]",
)
figure_layout = go.Layout(
    title="Galaxy Simulation",
    scene=axis_titles,
    updatemenus=[dict(type="buttons", showactive=False, buttons=[play_button])],
)

fig = go.Figure(data=[initial_trace], layout=figure_layout, frames=frames)

# --- Save HTML ---
# Write the figure, including its animation frames, to OUTPUT_HTML and tell
# the user where to find it.
print(f"Saving HTML animation to {OUTPUT_HTML} ...")
fig.write_html(file=OUTPUT_HTML)
print("Done! Open the HTML in a browser to see the animation.")


Loading simulation data...
Loaded positions: (1, 1000, 3), 1 timesteps, 1000 bodies
Building frames...
Frame 0/1 created
Creating figure...
Saving HTML animation to galaxy_sim.html ...
Done! Open the HTML in a browser to see the animation.
