In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Circle
from matplotlib.animation import FFMpegWriter
# --- Helper functions ---

def compute_accelerations(positions, velocities, friction_coef=None):
    """Return per-ball accelerations caused by table friction.

    Friction is modelled as linear drag, ``a = -friction_coef * v``, applied
    independently to every ball, so it always points opposite to the ball's
    velocity.

    Parameters
    ----------
    positions : array_like
        Ball positions. The drag model does not depend on them; the parameter
        is kept so existing callers ``compute_accelerations(positions,
        velocities)`` keep working.
    velocities : array_like
        Ball velocities, one row per ball.
    friction_coef : float, optional
        Drag coefficient. When omitted, the notebook-level ``friction``
        constant is used (looked up at call time).

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``velocities``.
    """
    coef = friction if friction_coef is None else friction_coef
    # Vectorised over all balls. Always float, so the drag is never truncated
    # the way np.zeros_like(positions) would truncate it for integer positions.
    return -coef * np.asarray(velocities, dtype=float)

def check_border_collision(positions, velocities, radius=None, width=None,
                           height=None, restitution=0.9):
    """Bounce balls off the table cushions, modifying both arrays in place.

    A ball whose centre is within ``radius`` of a cushion is clamped back onto
    the legal area. If it is moving *into* that cushion, the normal velocity
    component is reversed and scaled by ``restitution``. A ball already
    moving away from the cushion keeps its velocity. Otherwise it would be
    flipped back into the wall and jitter there. That situation can arise
    after ball-ball separation pushes a ball past the border.

    Parameters
    ----------
    positions, velocities : numpy.ndarray
        Float arrays of shape (n_balls, 2) (x, y); mutated in place.
    radius, width, height : float, optional
        Ball radius and table size. Default to the notebook-level
        ``ball_radius``, ``table_width`` and ``table_height``.
    restitution : float, optional
        Fraction of normal speed kept after a cushion bounce (default 0.9).

    Returns
    -------
    None
    """
    r = ball_radius if radius is None else radius
    w = table_width if width is None else width
    h = table_height if height is None else height

    # (lower, upper) allowed centre coordinate for axis 0 (x) and axis 1 (y).
    limits = ((r, w - r), (r, h - r))

    for i in range(len(positions)):
        for axis, (low, high) in enumerate(limits):
            if positions[i, axis] <= low:
                positions[i, axis] = low
                if velocities[i, axis] < 0:  # only reflect if heading into the cushion
                    velocities[i, axis] = -velocities[i, axis] * restitution
            elif positions[i, axis] >= high:
                positions[i, axis] = high
                if velocities[i, axis] > 0:
                    velocities[i, axis] = -velocities[i, axis] * restitution

def check_ball_collision(positions, velocities, ball_masses=None, radius=None,
                         restitution=0.9):
    """Resolve ball-ball contacts, modifying both arrays in place.

    For every overlapping pair (centre distance < 2 * radius), this function:

    * applies the standard impulse along the contact normal,
      ``J = (1 + e) * v_n / (1/m_i + 1/m_j)``. Afterwards the pair's normal
      relative speed is exactly ``-e`` times what it was, and momentum is
      conserved. The impulse is applied only if the balls are approaching;
      pairs that are already separating are left alone, otherwise they would
      be pulled back together.
    * pushes the balls apart so they no longer overlap. The correction is
      split in inverse proportion to mass (50/50 for equal masses).

    Parameters
    ----------
    positions, velocities : numpy.ndarray
        Float arrays of shape (n_balls, dim); mutated in place.
    ball_masses : array_like, optional
        Mass per ball. Defaults to the notebook-level ``masses``.
    radius : float, optional
        Ball radius. Defaults to the notebook-level ``ball_radius``.
    restitution : float, optional
        Coefficient of restitution e (default 0.9; 1.0 = perfectly elastic).

    Returns
    -------
    None
    """
    m = masses if ball_masses is None else ball_masses
    r = ball_radius if radius is None else radius
    min_dist = 2 * r

    n = len(positions)
    for i in range(n):
        for j in range(i + 1, n):
            diff = positions[j] - positions[i]
            dist = np.linalg.norm(diff)
            if dist >= min_dist:
                continue

            if dist > 0:
                normal = diff / dist  # unit vector from ball i to ball j
            else:
                # Exactly coincident centres: any direction works; avoid 0/0 -> NaN.
                normal = np.zeros_like(diff)
                normal[0] = 1.0

            inv_mi = 1.0 / m[i]
            inv_mj = 1.0 / m[j]
            inv_sum = inv_mi + inv_mj

            # Positive when ball i moves toward ball j along the normal.
            approach_speed = np.dot(velocities[i] - velocities[j], normal)
            if approach_speed > 0:
                impulse = (1.0 + restitution) * approach_speed / inv_sum
                velocities[i] -= impulse * inv_mi * normal
                velocities[j] += impulse * inv_mj * normal

            # Separate them so they don't overlap (lighter ball moves more).
            overlap = min_dist - dist
            positions[i] -= overlap * normal * (inv_mi / inv_sum)
            positions[j] += overlap * normal * (inv_mj / inv_sum)

def update(frame):
    """Advance the pool simulation by one ``dt`` and refresh the drawn artists.

    ``FuncAnimation`` callback. ``frame`` is ignored: each call advances the
    shared notebook-level ``positions`` / ``velocities`` arrays by exactly one
    time step, whatever frame index is being drawn.

    Returns the list of artists changed this frame (needed because the
    animation is created with ``blit=True``).
    """
    global positions, velocities

    # Semi-implicit Euler: update velocity from friction, then move with the
    # freshly updated velocity.
    velocities += compute_accelerations(positions, velocities) * dt
    positions += velocities * dt

    # Contact handling: cushions first, then ball-vs-ball.
    check_border_collision(positions, velocities)
    check_ball_collision(positions, velocities)

    # Sync each drawn circle with its ball's simulated centre.
    for idx, patch in enumerate(circles):
        patch.center = (positions[idx, 0], positions[idx, 1])

    # Shot parameters of the current run (the angle is in radians).
    starting_label.set_text(
        f"Initial speed: {initial_speed:.2f} Angle: {initial_angle:.2f}"
    )

    changed_artists = list(circles)
    changed_artists.append(starting_label)
    return changed_artists


In [None]:
# ---- Reproducibility ----
# The shots below are drawn with np.random; seeding here makes
# Restart & Run All reproduce the same videos.
SEED = 42
np.random.seed(SEED)

# ---- Simulation / table parameters ----
n_simulations = 10
dt = 0.01          # Time step (one animation frame advances the physics by dt)
num_steps = 5000   # Number of simulation steps (NOTE: the animation cell draws a fixed frame count instead)
friction = 0.2     # Linear drag coefficient: acceleration = -friction * velocity
ball_radius = 0.2  # Radius of the balls
table_width = 6    # Width of the pool table
table_height = 4   # Height of the pool table
initial_speed = 10  # Default cue-ball speed; each simulation redraws it at random
center_of_target = [5, 2]  # Centre of the rack of object balls
center_of_start = [1, 2]   # Cue-ball starting spot

# ---- Initial ball layout ----
# Rows 0-5: object balls in a small rack around center_of_target, with
# ~0.05-0.1 gaps so no two balls overlap at t=0.
# Last row: the cue ball (drawn in white).
starting_positions = np.array([
        [center_of_target[0]+ball_radius+0.05, center_of_target[1]],
        [center_of_target[0]+ball_radius+0.05, center_of_target[1]+2*ball_radius+0.05],
        [center_of_target[0]+ball_radius+0.05, center_of_target[1]-2*ball_radius-0.05],
        [center_of_target[0]-ball_radius*(1-1/np.sqrt(2))-0.05, center_of_target[1]+ball_radius],
        [center_of_target[0]-ball_radius*(1-1/np.sqrt(2))-0.05, center_of_target[1]-ball_radius],
        [center_of_target[0]-(2*np.sqrt(2)-1)*ball_radius-0.1, center_of_target[1]],
        [center_of_start[0], center_of_start[1]]
    ], dtype=float)

# One unit mass per ball; derived from the layout so the two can't drift apart.
masses = np.ones(len(starting_positions), dtype=float)

# Every ball starts at rest; the cue-ball shot is applied per simulation.
starting_velocities = np.zeros_like(starting_positions)

from pathlib import Path

OUTPUT_DIR = Path('../data')
# The video writer does not create missing folders, so make sure it exists.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
N_FRAMES = 2000  # Frames per video; every frame advances the physics by one dt


def reset_state():
    """Put every ball back on its starting spot with the current shot applied.

    Used as the ``FuncAnimation`` ``init_func``. Every render (the saved video
    and any later interactive replay) therefore starts from the break, rather
    than from wherever an earlier run left the shared ``positions`` /
    ``velocities`` arrays.

    Reads the notebook-level ``shot_velocity``, ``circles``, ``starting_label``,
    ``initial_speed`` and ``initial_angle`` set by the loop below. Returns the
    artists to redraw (required because the animation uses ``blit=True``).
    """
    global positions, velocities
    positions = starting_positions.copy()
    velocities = starting_velocities.copy()
    velocities[-1] = shot_velocity  # the last ball is the cue ball
    for circle, (x, y) in zip(circles, positions):
        circle.center = (x, y)
    starting_label.set_text(f"Initial speed: {initial_speed:.2f} Angle: {initial_angle:.2f}")
    return circles + [starting_label]


for k in range(n_simulations):

    # Random cue-ball shot: speed in [5, 15), angle in [-0.01*pi, 0.01*pi)
    # radians measured from the +x axis.
    initial_speed = np.random.uniform(5, 15)
    initial_angle = np.random.uniform(-np.pi*0.01, np.pi*0.01)
    shot_velocity = np.array([initial_speed * np.cos(initial_angle),
                              initial_speed * np.sin(initial_angle)], dtype=float)

    # --- Plot setup ---
    fig, ax = plt.subplots()
    ax.set_xlim(0, table_width)
    ax.set_ylim(0, table_height)
    ax.set_aspect('equal', adjustable='box')

    starting_label = ax.text(
        0.95, 0.95, "",
        transform=ax.transAxes,
        ha='right',  # align right
        va='top',    # align top
        fontsize=12,
        color='black'
    )

    # One Circle patch per ball; the cue ball (last row) is white.
    circles = []
    n_balls = len(starting_positions)
    for i in range(n_balls):
        color = 'white' if i == n_balls - 1 else f'C{i}'
        circle = Circle(
            (starting_positions[i, 0], starting_positions[i, 1]),
            radius=ball_radius,
            facecolor=color,
            edgecolor='black'
        )
        ax.add_patch(circle)
        circles.append(circle)

    # Initialise the shared simulation state explicitly, before animating.
    reset_state()

    ani = FuncAnimation(fig, update, frames=N_FRAMES, init_func=reset_state,
                        interval=20, blit=True)

    writer = FFMpegWriter(fps=30, metadata=dict(artist='Fedo', title='Pool Simulation'))

    # Save BEFORE showing. With a GUI backend, plt.show() blocks while the
    # animation plays, and saving an animation whose window has already been
    # closed is not supported by matplotlib.
    ani.save(str(OUTPUT_DIR / f'pool_simulation_{k}.mp4'), writer=writer, dpi=300)
    plt.show()
    plt.close(fig)  # Free the figure; otherwise n_simulations figures accumulate