## class State()

This class stores a dictionary of all the balls in play, along with the dimensions of the table and the simulation constants.

`balls` is a dict of `Ball` objects, keyed by ball ID, that must be passed to `__init__`. This could be a tuple, but a dict means we can easily look up and update a specific ball (such as the cue ball).

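Question: how should we raise errors, and how should we test that they are raised correctly? I saw Hayden used `assert` statements in the methods of `Ball`... (Note that `assert` statements are stripped when Python runs with `-O`, so raising an exception such as `ValueError` is more reliable for input validation; tests can then check for it with `pytest.raises`.)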

In [28]:
%%writefile state/__init__.py

from __future__  import annotations
import dataclasses
from ball import Ball
import numpy as np
import copy


def find_collision_angle(p_a: list[float, float], p_b: list[float, float]) -> float:
    """Return the angle, in radians, of the tangent plane between two touching balls.

    The line of centres from ``p_a`` to ``p_b`` points along
    ``arctan2(dy, dx)``, which lies in (-pi, pi]. The tangent plane is
    perpendicular to that line, so the result is the line-of-centres angle
    minus pi/2 and lies in (-3pi/2, pi/2].
    """
    dx = p_b[0] - p_a[0]
    dy = p_b[1] - p_a[1]
    line_of_centres = np.arctan2(dy, dx)
    # the tangent plane is a quarter turn clockwise from the line of centres
    return line_of_centres - (np.pi / 2)

        
def rotate_p_and_v(p:list[float, float], v:list[float, float], theta_radians:float):
    """Express a position and a polar velocity in a frame rotated by ``theta_radians``.

    ``p`` is Cartesian ``[x, y]`` and ``v`` is polar ``[magnitude, direction]``.
    ``theta_radians`` is the collision angle, i.e. the angle of the new frame
    relative to the original one. Rotating the frame by theta rotates the
    position by -theta. The velocity keeps its magnitude and its direction
    loses theta.

    Returns ``([x', y'], [magnitude, direction - theta_radians])`` as new lists;
    the inputs are not modified.
    """
    angle = -theta_radians
    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    x, y = p[0], p[1]
    rotated_p = [x * cos_a - y * sin_a, x * sin_a + y * cos_a]
    rotated_v = [v[0], v[1] - theta_radians]
    return rotated_p, rotated_v
        
def collision_confirmed(p_ay: float, p_by: float, v_ay: float, v_by: float) -> bool:
    """Return True if two overlapping balls are actually closing on each other.

    The arguments are the y positions and y velocities of balls a and b in the
    rotated collision frame produced by ``find_collision_angle`` and
    ``rotate_p_and_v``. In that frame the y axis lies along the line joining
    the two centres, so only the y components matter.

    The balls approach exactly when their y separation is shrinking, i.e. the
    lower ball moves up faster than the upper ball. Equal y velocities, or one
    ball stationary while the other moves away, are NOT collisions. Treating
    them as collisions would re-trigger (and undo) a velocity exchange while
    two balls still overlap on the step after impact.

    If both y positions are equal, the direction of approach is undefined and
    True is returned.
    """
    if p_ay < p_by:
        # b is above a: the gap closes only if a rises faster than b
        return bool(v_ay > v_by)
    if p_ay > p_by:
        # a is above b: the gap closes only if b rises faster than a
        return bool(v_by > v_ay)
    # coincident along the line of centres: no separating direction to test
    return True
    
def post_collision_velocities(v_ax:float, v_ay:float, v_bx:float, v_by:float, collision_angle: float):
    """Exchange the line-of-centres velocity components of two balls.

    The inputs are the Cartesian velocity components of balls a and b in the
    rotated collision frame, whose x axis lies along the tangent plane at
    ``collision_angle``. Each ball keeps its own x (tangential) component and
    takes the other ball's y (normal) component. Each result is converted to
    polar ``[magnitude, direction]`` and rotated back into the table frame by
    adding ``collision_angle``.

    Returns ``([speed_a, theta_a], [speed_b, theta_b])``. Each theta is an
    arctan2 result in (-pi, pi] shifted by ``collision_angle``; it is not
    normalised to any fixed interval.
    """
    def to_table_polar(tangential, normal):
        # polar form in the collision frame, then undo the frame rotation
        speed = (tangential**2 + normal**2)**0.5
        heading = np.arctan2(normal, tangential) + collision_angle
        return [speed, heading]

    return to_table_polar(v_ax, v_by), to_table_polar(v_bx, v_ay)

def _resolve_ball_collision(moving_ball, other) -> None:
    """Swap the normal velocity components of two overlapping balls if they are approaching."""
    # rotate both balls into the frame whose y axis is the line of centres
    collision_angle = find_collision_angle(moving_ball.p, other.p)
    p_a, v_a = rotate_p_and_v(moving_ball.p, moving_ball.v, collision_angle)
    p_b, v_b = rotate_p_and_v(other.p, other.v, collision_angle)

    # decompose the polar velocities into components in the rotated frame
    v_ax = v_a[0] * np.cos(v_a[1])
    v_ay = v_a[0] * np.sin(v_a[1])
    v_bx = v_b[0] * np.cos(v_b[1])
    v_by = v_b[0] * np.sin(v_b[1])

    # overlapping balls that are already separating must not be bounced again
    if collision_confirmed(p_a[1], p_b[1], v_ay, v_by):
        moving_ball.v, other.v = post_collision_velocities(v_ax, v_ay, v_bx, v_by, collision_angle)


def _bounce_off_cushions(ball, w: float, h: float) -> None:
    """Reflect a ball's velocity off every cushion it touches and is moving into (in place)."""
    v_x = ball.v[0] * np.cos(ball.v[1])
    v_y = ball.v[0] * np.sin(ball.v[1])
    x, y = ball.p[0], ball.p[1]
    r = ball.radius

    # The x and y checks are independent so that a ball touching two cushions
    # at once (a corner) reflects both components, not just the first match.
    if (x + r >= w / 2 and v_x > 0) or (x - r <= -w / 2 and v_x < 0):
        v_x = -v_x
    if (y + r >= h / 2 and v_y > 0) or (y - r <= -h / 2 and v_y < 0):
        v_y = -v_y

    ball.v[0] = (v_x**2 + v_y**2)**0.5
    ball.v[1] = np.arctan2(v_y, v_x)


def update_one_step(balls:dict, dt:float, acc:float,  w:float, h:float) -> dict:
    """Advance every moving ball by one time step, resolving ball and cushion collisions.

    Each ball object is expected to provide ``p`` ([x, y]), ``v`` (a mutable
    [speed, direction] list), ``radius``, ``time_step(dt, acc)``,
    ``collides_with(other)`` and ``collides_with_table(w, h)``.
    Pockets are NOT checked and no balls are removed from the dict.

    Args:
        balls: dict of ball objects keyed by ID; the balls are mutated in place.
        dt: time step passed to each ball's ``time_step``.
        acc: acceleration passed to each ball's ``time_step``.
        w: table width; the x cushions sit at +/- w/2.
        h: table length; the y cushions sit at +/- h/2.

    Returns:
        The same ``balls`` dict, after updating.
    """
    for ID, moving_ball in balls.items():
        # stationary balls only move when another ball hits them
        if moving_ball.v[0] == 0.0:
            continue

        moving_ball.time_step(dt, acc)
        if moving_ball.v[0] == 0.0:
            continue

        for other_ID, other in balls.items():
            if other_ID != ID and moving_ball.collides_with(other):
                _resolve_ball_collision(moving_ball, other)

        if moving_ball.collides_with_table(w, h):
            _bounce_off_cushions(moving_ball, w, h)

    return balls

    
@dataclasses.dataclass
class State():
    """The balls in play on a pool table plus the table and simulation constants.

    ``update`` applies a cue shot and simulates until every ball has stopped.
    It records a snapshot of the balls for each time step in ``log`` and the
    IDs of balls that dropped into pockets in ``pocketed``.
    """

    balls: dict               # dictionary of all balls in play, keyed by ID
    log: list[dict]           # one deep copy of the balls dict per simulated time step
    pocketed: list            # ball IDs pocketed during the most recent update, in order
    W_TABLE: float
    H_TABLE: float
    BALL_RADIUS: float
    DT: float
    ACCELERATION: float

    def __init__(self, initial_balls: dict[int, Ball]):
        """Store the starting balls (keyed by ball ID) and set the default table constants."""
        self.balls = initial_balls

        self.log = []
        self.pocketed = []
        self.W_TABLE = 1.27             # meters
        self.H_TABLE = 2.54             # meters
        self.BALL_RADIUS = 0.05715/2    # meters
        self.DT = 0.01                  # seconds
        self.ACCELERATION = 0.5         # m/s^2

    def _strike_cue_ball(self, balls: dict, velocity: float, degrees: float) -> None:
        """Give the cue ball (ID 0) the shot velocity, re-spotting it first if it is missing.

        The cue ball can be missing if it was pocketed on a previous update.
        It is re-spotted at (0, -0.635); -0.635 equals -H_TABLE / 4 for the
        default 2.54 m table.
        """
        if 0 not in balls:
            balls[0] = Ball(0, self.BALL_RADIUS, 0, -0.635, 0, 0)
        balls[0].v = [velocity, np.radians(degrees)]

    def update_old(self, velocity: float, degrees: float):
        """Legacy simulator: step only the balls known to be moving until all have stopped.

        Superseded by ``update``. Applies the shot to the cue ball and then
        repeats the following until no ball is moving:
          * step every moving ball;
          * resolve ball-ball and cushion collisions;
          * remove pocketed balls;
          * treat any speed <= 0.001 as stopped.
        Sets ``self.balls``, ``self.pocketed`` and ``self.log``.
        """
        balls = copy.deepcopy(self.balls)
        log = []
        self._strike_cue_ball(balls, velocity, degrees)

        balls_in_motion = [ID for ID, ball in balls.items() if ball.v[0] != 0.0]
        pocketed_this_turn = []

        while balls_in_motion:
            new_balls_in_motion = []

            # Iterate over a snapshot: pocketed balls are removed from the live
            # list below, and removing from a list while iterating over it
            # silently skips the element that follows.
            for ID in list(balls_in_motion):
                moving_ball = balls[ID]
                moving_ball.time_step(self.DT, self.ACCELERATION)

                # ball-ball collisions
                for other_ID, other in balls.items():
                    if other_ID == ID or not moving_ball.collides_with(other):
                        continue

                    # rotate into the frame whose y axis is the line of centres
                    collision_angle = find_collision_angle(moving_ball.p, other.p)
                    p_a, v_a = rotate_p_and_v(moving_ball.p, moving_ball.v, collision_angle)
                    p_b, v_b = rotate_p_and_v(other.p, other.v, collision_angle)

                    v_ax = v_a[0]*np.cos(v_a[1])
                    v_ay = v_a[0]*np.sin(v_a[1])
                    v_bx = v_b[0]*np.cos(v_b[1])
                    v_by = v_b[0]*np.sin(v_b[1])

                    # overlapping balls that are already separating are left alone
                    if not collision_confirmed(p_a[1], p_b[1], v_ay, v_by):
                        continue
                    moving_ball.v, other.v = post_collision_velocities(v_ax, v_ay, v_bx, v_by, collision_angle)

                    # a struck ball starts moving on the next step
                    if other_ID not in balls_in_motion and other_ID not in new_balls_in_motion:
                        new_balls_in_motion.append(other_ID)

                # cushion collisions; x and y are checked independently so a
                # corner contact reflects both components
                if moving_ball.collides_with_table(self.W_TABLE, self.H_TABLE):
                    v_x = moving_ball.v[0]*np.cos(moving_ball.v[1])
                    v_y = moving_ball.v[0]*np.sin(moving_ball.v[1])
                    x, y = moving_ball.p[0], moving_ball.p[1]
                    r = moving_ball.radius

                    if (x + r >= self.W_TABLE/2 and v_x > 0) or (x - r <= -self.W_TABLE/2 and v_x < 0):
                        v_x = -v_x
                    if (y + r >= self.H_TABLE/2 and v_y > 0) or (y - r <= -self.H_TABLE/2 and v_y < 0):
                        v_y = -v_y

                    moving_ball.v[0] = (v_x**2 + v_y**2)**0.5
                    moving_ball.v[1] = np.arctan2(v_y, v_x)

                # pockets
                if moving_ball.in_pocket(self.W_TABLE, self.H_TABLE):
                    balls_in_motion.remove(ID)
                    balls.pop(ID)
                    pocketed_this_turn.append(ID)

            balls_in_motion.extend(new_balls_in_motion)

            # rebuild rather than remove-while-iterating (which would skip balls)
            still_moving = []
            for ID in balls_in_motion:
                if balls[ID].v[0] <= 0.001:
                    balls[ID].v[0] = 0.0
                else:
                    still_moving.append(ID)
            balls_in_motion = still_moving

            log.append(copy.deepcopy(balls))

        self.balls = balls
        self.pocketed = pocketed_this_turn
        self.log = log

    def update(self, velocity: float, degrees: float):
        """Shoot the cue ball and simulate with ``update_one_step`` until every ball stops.

        Args:
            velocity: speed given to the cue ball (ID 0).
            degrees: shot direction in degrees; converted with ``np.radians``.

        Sets ``self.balls`` to the final layout (with pocketed balls removed),
        ``self.log`` to one deep-copied snapshot per time step, and
        ``self.pocketed`` to the pocketed IDs in the order they dropped.
        The stored ``self.balls`` is deep-copied first, so the caller's
        original ball objects are not mutated.
        """
        balls = copy.deepcopy(self.balls)
        log = []
        pocketed = []
        self._strike_cue_ball(balls, velocity, degrees)

        # NOTE(review): termination relies on Ball.time_step eventually setting
        # the speed to exactly 0.0 (update_old used a 0.001 threshold) - confirm in Ball.
        while True:
            balls = update_one_step(balls, self.DT, self.ACCELERATION, self.W_TABLE, self.H_TABLE)

            # collect first, then remove, so the dict is not resized mid-iteration
            dropped = [ID for ID, ball in balls.items() if ball.in_pocket(self.W_TABLE, self.H_TABLE)]
            for ID in dropped:
                balls.pop(ID)
            pocketed.extend(dropped)

            log.append(copy.deepcopy(balls))

            if all(ball.v[0] == 0.0 for ball in balls.values()):
                break

        self.balls = balls
        self.log = log
        self.pocketed = pocketed


Overwriting state/__init__.py


# Tests
We still need to write the full set of tests. I plan to break the update method up further into helper functions to make testing easier.

It would probably be good to parametrize the tests so they run on a variety of initial ball dicts.

I'm including import statements in the test cell in case we want to export it to a test file

In [1]:
from ball import Ball
from state import *
from pytest import approx
from pytest import raises
import numpy as np

RADIUS = 0.05715/2

# Ball 0 sits at the origin. Balls 1-8 surround it on a 3x3 grid with 0.2
# spacing, going counter-clockwise from +x. The last two Ball arguments are
# 0, 0 (presumably speed and direction, so every ball starts at rest).
_GRID_POSITIONS = [
    (0, 0),
    (0.2, 0),
    (0.2, 0.2),
    (0, 0.2),
    (-0.2, 0.2),
    (-0.2, 0),
    (-0.2, -0.2),
    (0, -0.2),
    (0.2, -0.2),
]
balls = {
    ball_id: Ball(ball_id, RADIUS, x, y, 0, 0)
    for ball_id, (x, y) in enumerate(_GRID_POSITIONS)
}

def test_constructor():
    # the constructor stores the given dict as-is
    state = State(balls)
    cue = state.balls[0]
    assert state.balls == balls
    assert cue == Ball(0,RADIUS,0,0,0,0)
    assert cue.p[0] == 0
    
def test_find_collision_angle():
    # expected tangent-plane angle between ball 0 (origin) and each grid ball
    expected_angles = {
        1: -np.pi/2,
        2: -np.pi/4,
        3: 0,
        4: np.pi/4,
        5: np.pi/2,
        6: -5*np.pi/4,
        7: -np.pi,
        8: -3*np.pi/4,
    }
    for other_id, angle in expected_angles.items():
        assert find_collision_angle(balls[0].p, balls[other_id].p) == approx(angle)
    
def test_rotate_p_and_v():
    # rotating the frame by pi/4 turns (1, 0) to angle -pi/4 and the heading pi/2 to pi/4
    position, velocity = rotate_p_and_v([1,0], [5, np.pi/2], np.pi/4)
    half_root2 = np.sqrt(2)/2
    x_rot, y_rot = position
    speed, heading = velocity
    assert x_rot == approx(half_root2)
    assert y_rot == approx(-half_root2)
    assert speed == 5
    assert heading == np.pi/4
    
    
def test_collision_confirmed():
    # (p_ay, p_by, v_ay, v_by) pairs that must NOT count as a collision
    separating = [
        (1, -1, 1, -1),     # headed away from each other
        (-1, 1, -1, 1),
        (1, -1, 2, 1),      # both headed up, the leading ball faster
        (-1, 1, 1, 2),
        (1, -1, -1, -2),    # both headed down, the leading ball faster
        (-1, 1, -2, -1),
    ]
    # pairs whose separation is shrinking, i.e. true collisions
    approaching = [
        (1, -1, -1, 1),
        (1, -1, -2, -1),
        (1, -1, 1, 2),
        (-1, 1, 1, -1),
        (-1, 1, 2, 1),
        (-1, 1, -1, -2),
    ]
    for case in separating:
        assert not collision_confirmed(*case)
    for case in approaching:
        assert collision_confirmed(*case)
    
def test_collision_confirmed_integrated():
    # two balls on converging paths, run through the real frame rotation
    ball_a = Ball(1,RADIUS,0,-3,5,0)
    ball_b = Ball(2,RADIUS,2,-2,5,3*np.pi/2)

    angle = find_collision_angle(ball_a.p, ball_b.p)
    (_, a_y), (a_speed, a_heading) = rotate_p_and_v(ball_a.p, ball_a.v, angle)
    (_, b_y), (b_speed, b_heading) = rotate_p_and_v(ball_b.p, ball_b.v, angle)
    a_vy = a_speed*np.sin(a_heading)
    b_vy = b_speed*np.sin(b_heading)

    assert collision_confirmed(a_y, b_y, a_vy, b_vy) == True
    
def test_post_collision_velocities():
    quarter_turn = np.pi/4

    # 3-4-5 and 5-12-13 triangles: exact speeds once the y components swap
    a1, b1 = post_collision_velocities(4, 5, 12, -3, quarter_turn)
    assert a1[0] == approx(5)
    assert b1[0] == approx(13)
    assert a1[1] == approx(np.arctan2(-3, 4) + quarter_turn)
    assert b1[1] == approx(np.arctan2(5, 12) + quarter_turn)

    # each ball ends with unit speed at a special angle; checked unrotated and rotated by pi/4
    components = (np.sqrt(3)/2, np.sqrt(2)/2, -np.sqrt(2)/2, -1/2)
    a2, b2 = post_collision_velocities(*components, 0)
    a3, b3 = post_collision_velocities(*components, quarter_turn)
    assert a2[1] == approx(-np.pi/6)
    assert b2[1] == approx(3*np.pi/4)
    assert a3[1] == approx(np.pi/12)
    assert b3[1] == approx(np.pi)
    
def test_update_reset_cue():
    # the starting layout has no cue ball; update() must re-spot ball 0
    state = State({
        1: Ball(1,RADIUS,0.2,0,0,0),
        2: Ball(2,RADIUS,0.2,0.2,0,0),
    })
    state.update(0, 0)
    assert 0 in state.balls
    
def test_update_simple():
    # a shot at 0 degrees sends the cue into ball 1, which should end up pocketed
    state = State(balls)
    state.update(1, 0)
    assert 0 in state.balls
    assert 1 not in state.balls
    assert state.pocketed == [1]

In [2]:
for _run_test in (
    test_constructor,
    test_find_collision_angle,
    test_rotate_p_and_v,
    test_collision_confirmed,
    test_collision_confirmed_integrated,
    test_post_collision_velocities,
    test_update_reset_cue,
    test_update_simple,
):
    _run_test()

update num: 1
update num: 1
update num: 2
update num: 3
update num: 4
ball 0 stopped
update num: 5
update num: 6
update num: 7
update num: 8
update num: 9
update num: 10
update num: 11
update num: 12
update num: 13
update num: 14
ball 1 in pocket
