# In [4]:
import subprocess

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, FFMpegWriter
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# ----------------------------
# Parameters
# ----------------------------
N = 9                    # number of fish
FPS = 60                 # animation / video frame rate
DURATION = 10            # length of the animation in seconds
STEPS = FPS * DURATION   # total number of frames rendered
DT = 1 / FPS             # simulated time advanced per frame

SIZE = 0.05     # fish glyph scale: the tip sits SIZE ahead of the fish centre
SPEED = 0.5     # constant swimming speed (cube widths per simulated second)
INERTIA = 0.9   # blend weight of the previous heading; random noise gets 1 - INERTIA
WIGGLE = 0.5    # magnitude of the random heading noise, relative to SPEED

# Node selection: index of the tracked fish (drawn red and linked to its
# nearest neighbour). Must satisfy 0 <= NODE_A < N; validated here so a bad
# value fails immediately instead of deep inside the animation callback.
NODE_A = 0
if not 0 <= NODE_A < N:
    raise ValueError(f"NODE_A must be in [0, {N}), got {NODE_A}")

# Random seed: SEED = None gives different motion on every run; set an int
# (e.g. SEED = 42) for repeatable runs. Seeds NumPy's global RNG, which is
# what the rest of the script draws from.
SEED = None
if SEED is not None:
    np.random.seed(SEED)

# ----------------------------
# Helper functions
# ----------------------------
def normalize(v):
    """Return ``v`` rescaled to unit Euclidean length.

    When the norm is not greater than 1e-9 (including a NaN norm), the
    input object itself is returned untouched, avoiding a division by
    (nearly) zero.
    """
    length = np.linalg.norm(v)
    if length > 1e-9:
        return v / length
    return v

def fish_triangle(pos, vel):
    """Build the three vertices of a triangular 'fish' pointing along ``vel``.

    The nose sits SIZE ahead of ``pos`` along the unit heading. The two tail
    corners sit 0.6 * SIZE behind it, offset sideways by +/- 0.5 * SIZE along
    an axis perpendicular to the heading.

    Returns a (3, 3) array: [nose, tail-left, tail-right].
    """
    forward = normalize(vel)
    z_axis = np.array([0.0, 0.0, 1.0])
    # A reference axis nearly parallel to the heading would make the cross
    # product (and therefore the fish width) collapse, so fall back to +y.
    if abs(np.dot(forward, z_axis)) > 0.9:
        reference = np.array([0.0, 1.0, 0.0])
    else:
        reference = z_axis

    side = normalize(np.cross(forward, reference))
    half_width = side * SIZE * 0.5

    nose = pos + forward * SIZE
    tail = pos - forward * SIZE * 0.6
    return np.vstack([nose, tail - half_width, tail + half_width])

def bounce(p, v, lo=0.0, hi=1.0):
    """Reflect a point off the walls of the axis-aligned box [lo, hi]^d.

    For each coordinate that has left the box, the position is clamped back
    onto the wall and the matching velocity component is made to point
    inward. Using ``abs`` rather than a plain sign flip keeps the velocity
    inward even when the point is already outside but heading back in; a
    plain flip would send it outward again.

    Args:
        p: Mutable position sequence (modified in place).
        v: Mutable velocity sequence of the same length (modified in place).
        lo: Lower wall coordinate (default 0.0, the unit cube).
        hi: Upper wall coordinate (default 1.0, the unit cube).

    Returns:
        The same ``(p, v)`` objects, for convenient tuple assignment.
    """
    for axis in range(len(p)):
        if p[axis] < lo:
            p[axis] = lo
            v[axis] = abs(v[axis])      # must move toward +axis
        elif p[axis] > hi:
            p[axis] = hi
            v[axis] = -abs(v[axis])     # must move toward -axis
    return p, v

# ----------------------------
# Initialize state
# ----------------------------
# Uniform random positions inside the unit cube; random headings
# normalised to the common swimming speed SPEED.
pos = np.random.rand(N, 3)
vel = np.random.randn(N, 3)
vel = np.array([normalize(v) for v in vel]) * SPEED

# ----------------------------
# Figure
# ----------------------------
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.set_box_aspect((1, 1, 1))
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_zlim(0, 1)
ax.set_title("N Fish in a Unit Cube (Nearest-Neighbor Link from Node A)")

# Optional: stable view angle
ax.view_init(elev=20, azim=-60)

# Cube wireframe: the 12 edges of the unit cube (bottom face, top face,
# then the four vertical edges).
edges = [
    [(0,0,0),(1,0,0)],[(1,0,0),(1,1,0)],[(1,1,0),(0,1,0)],[(0,1,0),(0,0,0)],
    [(0,0,1),(1,0,1)],[(1,0,1),(1,1,1)],[(1,1,1),(0,1,1)],[(0,1,1),(0,0,1)],
    [(0,0,0),(0,0,1)],[(1,0,0),(1,0,1)],[(1,1,0),(1,1,1)],[(0,1,0),(0,1,1)]
]
for a, b in edges:
    ax.plot([a[0], b[0]], [a[1], b[1]], [a[2], b[2]], lw=1)

# Fish patches: one triangle per fish, created and coloured in a single
# pass. The tracked fish NODE_A is red; all others are blue.
patches = []
for i in range(N):
    poly = Poly3DCollection([fish_triangle(pos[i], vel[i])], alpha=0.85)
    ax.add_collection3d(poly)
    poly.set_facecolor("red" if i == NODE_A else "tab:blue")
    patches.append(poly)

# Line from fish NODE_A to whichever other fish is currently closest.
# It starts empty and is repositioned every frame by the animation callback.
nearest_line, = ax.plot([], [], [], color="red", lw=2)

def update(_frame):
    """Advance the simulation by one frame and refresh the artists.

    For every fish:
      1. Blend its current heading (weight INERTIA) with a random unit
         direction scaled by WIGGLE * SPEED (weight 1 - INERTIA).
      2. Rescale the result to SPEED and step the position by DT.
      3. Reflect off the cube walls and redraw its triangle.

    Then ``nearest_line`` is moved so it joins fish NODE_A to the closest
    other fish. Closeness is squared Euclidean distance; ties go to the
    lowest index. With fewer than two fish the line is hidden instead.

    The module-level ``pos``/``vel`` arrays are mutated in place (never
    rebound), so no ``global`` declaration is needed.

    Returns:
        The list of modified artists. FuncAnimation only relies on this
        when blitting is enabled.
    """
    # ---- Fish motion ----
    for i in range(N):
        noise = normalize(np.random.randn(3)) * (WIGGLE * SPEED)
        vel[i] = normalize(INERTIA * vel[i] + (1.0 - INERTIA) * noise) * SPEED
        pos[i] = pos[i] + vel[i] * DT
        pos[i], vel[i] = bounce(pos[i], vel[i])
        patches[i].set_verts([fish_triangle(pos[i], vel[i])])

    # ---- Nearest-neighbor logic (Node A) ----
    a = NODE_A
    if N < 2:
        # No other fish to link to; np.argmin would raise on an empty set.
        nearest_line.set_data([], [])
        nearest_line.set_3d_properties([])
        return patches + [nearest_line]

    # Squared distances from A to every fish in one vectorised pass.
    # A itself is excluded by giving it an infinite distance.
    dists2 = np.sum((pos - pos[a]) ** 2, axis=1)
    dists2[a] = np.inf
    nearest = int(np.argmin(dists2))

    segment = pos[[a, nearest]]  # (2, 3): endpoints A and its nearest fish
    nearest_line.set_data(segment[:, 0], segment[:, 1])
    nearest_line.set_3d_properties(segment[:, 2])

    return patches + [nearest_line]

# Close the figure via pyplot, presumably so the notebook does not also
# render a static copy. The Figure object still drives the animation below.
plt.close(fig)
ani = FuncAnimation(fig, update, frames=STEPS, interval=1000/FPS, blit=False)

# ---- Save to MP4 using ffmpeg ----
# ffmpeg runs as an external subprocess. A missing binary or a broken pipe
# surfaces as OSError, and a non-zero ffmpeg exit as
# subprocess.CalledProcessError; neither is a RuntimeError. So check
# availability first and catch all of them during the write.
if not FFMpegWriter.isAvailable():
    print("Could not write MP4 via ffmpeg: ffmpeg executable not found.")
else:
    writer = FFMpegWriter(fps=FPS, bitrate=1800)
    try:
        ani.save("fish_tracking_N.mp4", writer=writer)
    except (RuntimeError, OSError, subprocess.CalledProcessError) as e:
        print("Could not write MP4 via ffmpeg. Error:")
        print(e)
    else:
        print("fish_tracking_N.mp4 written (overwritten each run)")



# Output: fish_tracking_N.mp4 written (overwritten each run)
