# In [None]:
# Add polar grid when wave is disabled to help visualize the inertial trajectory

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import os
import imageio.v2 as imageio
from tqdm import tqdm
import shutil

# Scratch directory that holds the per-frame PNGs before they are assembled
# into the GIF/MP4.
frames_dir = "./data/wave_frames"

# Start from an empty directory so frames left over from a previous run cannot
# leak into this one.  On the very first run the directory does not exist yet,
# which is fine: there is simply nothing to clear.
try:
    shutil.rmtree(frames_dir)
except FileNotFoundError:
    pass
os.makedirs(frames_dir, exist_ok=True)

# Parameters
i_movie = 3                     # index embedded in the exported GIF/MP4 file names
num_frames = 60                 # number of frames rendered by the main loop
show_wave = True                # True: draw the wave surface; False: polar grid + trail

# Quantities below are linearly interpolated from *_start to *_end over the frames.
radius_start = 40
radius_end = 200
amplitude_start = 0.0
amplitude_end = 0.1
particle_radius_start = 10
particle_radius_end = 11

# Polar angle of the particle's starting point on the circle of radius_start.
alpha_0 = np.pi / 4

# The particle follows the straight line through (x_start, y_start) and
# (x_end, y_end); m is the slope of that line.
x_start = radius_start * np.cos(alpha_0)
y_start = radius_start * np.sin(alpha_0)
x_end = x_start + 700
y_end = y_start + 40
m = (y_end - y_start) / (x_end - x_start)

def compute_alpha(r):
    """Return the polar angle where the particle's line meets the circle of radius r.

    The line is the one through the module-level point (x_start, y_start)
    with slope m, parameterised as (x_start + d, y_start + m * d).
    Substituting into x**2 + y**2 = r**2 gives a quadratic in d; the root
    taken is the one using +sqrt(discriminant).

    Returns:
        The angle np.arctan2(y, x) of the intersection point, or None when
        the discriminant is negative (the line does not reach the circle).
    """
    # Coefficients of quad_a * d**2 + quad_b * d + quad_c = 0.
    quad_a = 1 + m**2
    quad_b = 2 * (x_start + y_start * m)
    quad_c = x_start**2 + y_start**2 - r**2

    disc = quad_b**2 - 4*quad_a*quad_c
    if disc < 0:
        return None

    step = (-quad_b + np.sqrt(disc)) / (2 * quad_a)
    hit_x = x_start + step
    hit_y = y_start + m * step
    return np.arctan2(hit_y, hit_x)

def create_expanding_wave(radius, resolution=300, amplitude=1.0, sigma=10.0):
    """Build the surface mesh of a ring-shaped Gaussian wave centred on the origin.

    Args:
        radius: Radius at which the wave crest sits.
        resolution: Number of samples along both the angular and radial axes.
        amplitude: Height of the crest.
        sigma: Gaussian width of the ring in the radial direction.

    Returns:
        Tuple (X, Y, Z) of (resolution, resolution) arrays: Cartesian
        coordinates of a polar grid reaching out to 1.2 * radius, and the wave
        height at each grid point.
    """
    angular_samples = np.linspace(0, 2 * np.pi, resolution)
    radial_samples = np.linspace(0, radius * 1.2, resolution)
    R, Theta = np.meshgrid(radial_samples, angular_samples)

    # Height depends only on the radial distance from the crest.
    crest_offset = R - radius
    Z = amplitude * np.exp(-(crest_offset ** 2) / (2 * sigma ** 2))

    X = R * np.cos(Theta)
    Y = R * np.sin(Theta)
    return X, Y, Z

def add_particle(ax, radius, angle, wave_amplitude, sigma=10.0, particle_radius=0.2):
    """Draw a red sphere on ax at polar position (radius, angle).

    Args:
        ax: Axes object providing plot_surface.
        radius: Distance of the sphere centre from the origin.
        angle: Polar angle of the sphere centre, in radians.
        wave_amplitude: Wave height used for the sphere's z position.
        sigma: Width entering the Gaussian height term.
        particle_radius: Radius of the drawn sphere.
    """
    # Unit-sphere parameterisation: azimuth along axis 0, polar angle along axis 1.
    azimuth = np.linspace(0, 2 * np.pi, 200)
    polar = np.linspace(0, np.pi, 20)
    sphere_x = particle_radius * np.outer(np.cos(azimuth), np.sin(polar))
    sphere_y = particle_radius * np.outer(np.sin(azimuth), np.sin(polar))
    sphere_z = particle_radius * np.outer(np.ones(np.size(azimuth)), np.cos(polar))

    centre_x = radius * np.cos(angle)
    centre_y = radius * np.sin(angle)
    # NOTE: (radius - radius) is always zero, so the exponential is 1 and the
    # centre height equals wave_amplitude for any non-zero sigma.
    centre_z = wave_amplitude * np.exp(-((radius - radius) ** 2) / (2 * sigma ** 2))

    ax.plot_surface(
        sphere_x + centre_x,
        sphere_y + centre_y,
        sphere_z + centre_z,
        color='red',
    )

def set_axes_equal(ax):
    """Give the x, y and z axes of ax the same span, keeping each axis centre.

    Every axis is reset to [centre - h, centre + h], where h is half of the
    largest of the three current spans.
    """
    bounds = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])
    lows, highs = bounds[:, 0], bounds[:, 1]
    half_extent = 0.5 * max(highs - lows)
    midpoints = np.mean(bounds, axis=1)

    limit_setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)
    for mid, apply_limits in zip(midpoints, limit_setters):
        apply_limits([mid - half_extent, mid + half_extent])

def draw_polar_grid(ax, max_radius, num_radial_lines=12, num_circles=6):
    """Draw a thin gray polar grid (circles and spokes) just above z = 0.

    Args:
        ax: Axes object providing plot(x, y, z, ...).
        max_radius: Radius of the outermost circle and length of each spoke.
        num_radial_lines: Number of evenly spaced spokes from the origin.
        num_circles: Number of evenly spaced concentric circles.
    """
    grid_height = 0.01
    line_style = dict(color='gray', linewidth=0.5, alpha=0.5)

    # Concentric circles; the zero-radius sample is dropped.
    circle_radii = np.linspace(0, max_radius, num_circles + 1)[1:]
    sweep = np.linspace(0, 2 * np.pi, 200)
    for ring_radius in circle_radii:
        ring_x = ring_radius * np.cos(sweep)
        ring_y = ring_radius * np.sin(sweep)
        ring_z = np.full_like(ring_x, grid_height)
        ax.plot(ring_x, ring_y, ring_z, **line_style)

    # Spokes from the origin out to max_radius.
    spoke_angles = np.linspace(0, 2*np.pi, num_radial_lines, endpoint=False)
    for spoke_angle in spoke_angles:
        spoke_x = [0, max_radius * np.cos(spoke_angle)]
        spoke_y = [0, max_radius * np.sin(spoke_angle)]
        spoke_z = [grid_height, grid_height]
        ax.plot(spoke_x, spoke_y, spoke_z, **line_style)

# Output folders
frames_dir = "./data/wave_frames"
os.makedirs(frames_dir, exist_ok=True)
frame_paths = []  # paths of the PNG frames actually written, in order

# Interpolation runs from 0 on the first frame to 1 on the last.  Clamp the
# denominator so that a single-frame run (num_frames == 1) does not divide by zero.
interp_steps = max(num_frames - 1, 1)

for i in tqdm(range(num_frames)):
    interp = i / interp_steps
    current_radius = radius_start + interp * (radius_end - radius_start)
    current_amplitude = amplitude_start + interp * (amplitude_end - amplitude_start)
    current_particle_radius = particle_radius_start + interp * (particle_radius_end - particle_radius_start)

    # Resolve the particle position BEFORE allocating a figure: a frame that
    # is skipped here must not leave an unclosed figure behind.
    current_angle = compute_alpha(current_radius)
    if current_angle is None:
        continue

    fig = plt.figure(figsize=(8, 6))
    try:
        ax = fig.add_subplot(111, projection='3d')

        if show_wave:
            X, Y, Z = create_expanding_wave(current_radius, amplitude=current_amplitude, sigma=current_radius / 10)
            ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.7, linewidth=0)
        else:
            # Without the wave, a polar grid gives the eye a reference frame.
            draw_polar_grid(ax, max_radius=radius_end)

        add_particle(ax, current_radius, current_angle, wave_amplitude=current_amplitude,
                     sigma=current_radius / 10, particle_radius=current_particle_radius)

        if not show_wave:
            # Dotted trail of the positions visited so far (i + 1 samples
            # between the starting radius and the current one).
            trail_r = np.linspace(radius_start, current_radius, i + 1)
            trail_points = [(r, compute_alpha(r)) for r in trail_r]
            trail_points = [(r, a) for r, a in trail_points if a is not None]
            trail_x = [r * np.cos(a) for r, a in trail_points]
            trail_y = [r * np.sin(a) for r, a in trail_points]
            trail_z = [0.01] * len(trail_x)
            ax.scatter(trail_x, trail_y, trail_z, color='black', s=5)

        ax.set_xlim(-radius_end, radius_end)
        ax.set_ylim(-radius_end, radius_end)
        ax.set_zlim(0, 1.5)
        ax.axis('off')
        # Oblique camera; elev=90 would give a top-down view instead.
        ax.view_init(elev=25, azim=0)
        # Widen every axis to the largest span so the scene is not stretched.
        set_axes_equal(ax)

        frame_path = os.path.join(frames_dir, f"frame_{i:03d}.png")
        fig.savefig(frame_path, dpi=100, bbox_inches='tight')
        frame_paths.append(frame_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

# Export animation
gif_path = f"./data/inertial_motion_polar_grid_{i_movie}.gif"
mp4_path = f"./data/inertial_motion_polar_grid_{i_movie}.mp4"

# Each output is written from the same ordered list of frame PNGs; only the
# writer options differ.  The GIF is produced first, then the MP4.
export_targets = (
    (gif_path, {"mode": 'I', "duration": 0.15}),
    (mp4_path, {"fps": 30, "codec": 'libx264'}),
)
for out_path, writer_options in export_targets:
    with imageio.get_writer(out_path, **writer_options) as writer:
        for path in frame_paths:
            writer.append_data(imageio.imread(path))

# The PNG frames are no longer needed; leave an empty frames directory behind.
shutil.rmtree(frames_dir)
os.makedirs(frames_dir, exist_ok=True)
