In [None]:
"""The lunar lander lab uses the actor-environment formalism, in which
the actor takes an action and the environment evaluates that action to
produce a new observation state and a reward. These files are my attempt
at building a readable and extensible codebase around the lander and
this formalism. I have not actually implemented the lab; the point of
this exercise was to explore ways to productionize code for data
science.

The code is deliberately over-abstracted, both to provide flexibility
and to try out abstractions that I may want to use later.
"""
from dataclasses import dataclass, field
from enum import Enum, unique
from typing import Callable, Collection, Protocol, TypeAlias

class Action(Enum):
    """Discrete actions the lander can choose on each step.

    The integer values identify the actions; member order matches the
    values (0 through 3).
    """

    do_nothing = 0  # no engine is fired
    fire_main_engine = 1
    fire_left_engine = 2
    fire_right_engine = 3



@dataclass
class State:
    """Observation of the lunar lander at a single instant.

    Holds the lander's position, linear and angular motion, orientation and
    a contact flag for each leg. Every numeric field defaults to 0 and both
    contact flags default to ``False``.
    """
    # Position and linear velocity.
    x: float = 0
    y: float = 0
    x_velocity: float = 0
    y_velocity: float = 0
    # Orientation and its rate of change.
    angle: float = 0
    angular_velocity: float = 0
    # Leg contact flags.
    left_leg_contact: bool = False
    right_leg_contact: bool = False

    def step(self) -> None:
        """Advance one time step by adding each velocity to its coordinate.

        Only ``x``, ``y`` and ``angle`` change; the velocities themselves are
        left untouched.
        """
        self.x, self.y, self.angle = (
            self.x + self.x_velocity,
            self.y + self.y_velocity,
            self.angle + self.angular_velocity,
        )

# Ideal end state: resting at the origin, motionless and upright, with both
# legs touching down. A landing could still be considered successful with a
# small angle or some residual velocity, so this is a target rather than the
# only acceptable outcome.
desired_state = State(
    # At the origin...
    x=0, y=0,
    # ...with no linear or angular motion and no tilt...
    x_velocity=0, y_velocity=0,
    angle=0, angular_velocity=0,
    # ...and both legs in contact.
    left_leg_contact=True, right_leg_contact=True,
)

SurfaceFunction: TypeAlias = Callable[[float], float]

def flat_surface(x: float) -> float:
    """A flat surface function"""
    return 0.2

@unique
class BoundaryStates(Enum):
    """Boundary conditions of the state space, plus the default ``flying``.

    Every member must have a distinct value. ``Enum`` turns a member whose
    value repeats an earlier one into an *alias* of it. For example,
    ``landed = 0`` would make ``BoundaryStates.landed is BoundaryStates.flying``
    true, so every flying state would count as landed. ``@unique`` turns such
    a collision into an error at import time.
    """
    flying = 0  # default state: no boundary condition reached
    landed = 3  # distinct value; 1 and 2 were already taken below
    crashed = 1
    left_screen = 2

class BoundsCheck(Protocol):
    """Structural interface for boundary-condition checks.

    An implementation inspects a lander ``State`` and classifies it as one
    of the ``BoundaryStates``: a success or failure condition, or the
    default. These can be seen as boundary conditions on the state space.
    """

    def __call__(self, state: State) -> BoundaryStates:
        """Classify ``state`` against the boundary conditions."""

@dataclass
class MoonBounds(BoundsCheck):
    """Boundary conditions for the moon-landing scenario.

    ``desired_state`` serves both as the landing target and as a set of
    tolerances. A landing is safe when the lander is exactly at the desired
    position with the desired leg contacts, and when each velocity component
    and the angle are no larger in magnitude than the corresponding
    ``desired_state`` value.
    """
    surface_func: SurfaceFunction = flat_surface
    # ``State`` is an unhashable (mutable) dataclass, so it cannot be a plain
    # field default; Python 3.11+ rejects that with ``ValueError``. The
    # factory returns the module-level ``desired_state`` when an instance is
    # created.
    desired_state: State = field(default_factory=lambda: desired_state)
    # Horizontal extent of the screen (min, max), inclusive on both ends.
    x_bounds: tuple[float, float] = (0.0, 1.0)

    def __call__(self, state: State) -> BoundaryStates:
        """Classify ``state`` against the moon's boundary conditions.

        Checks run in priority order: safe landing, crash, leaving the
        screen; anything else is ``flying``. ``landed`` is checked before
        ``crashed`` because a touchdown normally also satisfies ``crashed``
        (``y`` at or below the surface). Checked second, a safe landing
        could never be reported.
        """
        if self.landed(state):
            return BoundaryStates.landed
        if self.crashed(state):
            return BoundaryStates.crashed
        if self.left_screen(state):
            return BoundaryStates.left_screen
        return BoundaryStates.flying

    def landed(self, state: State) -> bool:
        """Whether we have (safely) landed.

        Position and leg contacts must equal those of ``desired_state``. The
        magnitudes of ``x_velocity``, ``y_velocity``, ``angle`` and
        ``angular_velocity`` must not exceed the magnitudes of the matching
        ``desired_state`` fields, so motion or tilt in either direction is
        treated alike. An unsafe landing (too much angle and/or velocity) is
        currently treated just like not landing at all.
        """
        target = self.desired_state
        # NOTE(review): exact float equality on position will rarely hold
        # after integrating velocities; consider a position tolerance.
        return (
            state.x == target.x
            and state.y == target.y
            and state.left_leg_contact == target.left_leg_contact
            and state.right_leg_contact == target.right_leg_contact
            and abs(state.x_velocity) <= abs(target.x_velocity)
            and abs(state.y_velocity) <= abs(target.y_velocity)
            and abs(state.angle) <= abs(target.angle)
            and abs(state.angular_velocity) <= abs(target.angular_velocity)
        )

    def crashed(self, state: State) -> bool:
        """Whether we hit the moon surface.

        ``surface_func`` takes an x coordinate and returns the surface's y
        coordinate there. We have crashed when ``y`` is at or below it.
        """
        return state.y <= self.surface_func(state.x)

    def left_screen(self, state: State) -> bool:
        """Whether the lander is outside the horizontal screen bounds.

        The screen spans ``x_bounds`` inclusive; any ``x`` outside that
        closed interval counts as having left the screen.
        """
        x_min, x_max = self.x_bounds
        return not (x_min <= state.x <= x_max)
    


class RewardAssignment(Protocol):
    """Structural interface for reward functions.

    An implementation maps a ``State`` together with its ``BoundaryStates``
    classification (landed, crashed, left screen, etc.) to a scalar reward.
    This allows different reward functions both for different states and
    for different boundary conditions.
    """

    def __call__(self, state: State, boundary_state: BoundaryStates) -> float:
        """Return the reward for ``state`` given ``boundary_state``."""
    
@dataclass
class Reward:
    """Overall reward function combining boundary rewards with a state term.

    The reward depends on both the state and its boundary condition, so
    this class bundles fixed rewards for the boundary conditions with a
    pluggable ``observation_state_reward``. Instances are callable as the
    overall reward function.
    """
    observation_state_reward: RewardAssignment
    collision_penalty: float = -100
    screen_penalty: float = -100
    done_reward: float = 100

    def __call__(self, state: State, boundary_state: BoundaryStates = BoundaryStates.flying) -> float:
        """Return the reward for ``state`` under ``boundary_state``.

        ``crashed`` yields ``collision_penalty`` and ``left_screen`` yields
        ``screen_penalty``; neither adds the state-dependent term. When
        ``boundary_state`` is ``BoundaryStates.landed``, ``done_reward`` is
        added to ``observation_state_reward(state, boundary_state)``. Any
        other state receives that observation reward alone.
        """
        # Failure penalties are assumed to override every state-dependent
        # reward, so the observation reward is not even computed for them.
        if boundary_state is BoundaryStates.crashed:
            return self.collision_penalty
        if boundary_state is BoundaryStates.left_screen:
            return self.screen_penalty

        total: float = 0
        if boundary_state is BoundaryStates.landed:
            total += self.done_reward  # success bonus
        # TODO: implement flying observation-state dependent rewards
        total += self.observation_state_reward(state, boundary_state)
        return total
    
@dataclass
class EngineActions:
    """On/off status of each engine.

    Any combination of engines may be on at once; no assumption is made
    that only one action happens at a time.
    """
    main: bool = False
    left: bool = False
    right: bool = False

    def get_actions(self) -> set[Action]:
        """Return the actions corresponding to the engines that are on.

        A set is returned because the order of simultaneous actions must not
        matter. With every engine off, the result is
        ``{Action.do_nothing}``.
        """
        engine_to_action = (
            (self.main, Action.fire_main_engine),
            (self.left, Action.fire_left_engine),
            (self.right, Action.fire_right_engine),
        )
        firing = {action for is_on, action in engine_to_action if is_on}
        return firing or {Action.do_nothing}
    

class Policy(Protocol):
    """Structural interface for policies.

    A policy chooses the next ``Action`` to take given the current
    ``State``.
    """

    def __call__(self, state: State) -> Action:
        """Return the action to take in ``state``."""

class StateAction(Protocol):
    """Structural interface for state transitions driven by an action.

    Given a ``State`` and the ``Action`` taken, an implementation returns the
    resulting ``State``. It may update and return the given object in place.
    """

    def __call__(self, state: State, action: Action) -> State:
        """Return the state obtained by applying ``action`` to ``state``."""



def modify_state_with_action(state: State, action: Action) -> State:
    """Modify the state with the given action (in place modification)"""
    match action:
        case Action.do_nothing:
            pass
        case Action.fire_main_engine:
            state.y_velocity += 0.1
        case Action.fire_left_engine:
            state.x_velocity -= 0.05
            state.angular_velocity -= 0.05
        case Action.fire_right_engine:
            state.x_velocity += 0.05
            state.angular_velocity += 0.05
    gravity = -0.00  # Assuming negligible gravity
    state.y_velocity += gravity
    state.step()
    return state



@dataclass
class Agent:
    """The lander agent: holds its current state and applies actions to it.

    We've leaked the abstraction a bit, as we communicate over state
    instead of actions.
    """
    previous_action: Action = Action.do_nothing
    # Each agent gets its own fresh ``State``. A shared ``State()`` default
    # would be mutated in place by every agent via ``state_action``, and an
    # unhashable dataclass default is rejected outright on Python 3.11+.
    current_state: State = field(default_factory=State)
    policy: str = "SimplePolicy"  # TODO: implement policy
    # Transition function, called as ``state_action(current_state, action)``.
    state_action: StateAction = modify_state_with_action

    def take_action(self, action: Action) -> State:
        """Apply ``action`` through ``state_action`` and record it.

        Returns:
            The new current state.
        """
        # Technically the default transition modifies in place, but we
        # reassign explicitly. That is especially useful if we switch to
        # copy-on-write.
        self.current_state = self.state_action(self.current_state, action)
        self.previous_action = action
        return self.current_state


@dataclass
class Environment:
    """Resolves the agent's actions into a new state, a reward and a done flag.

    ``surface`` is used here only to detect leg contact, while boundary
    classification is delegated to ``bounds``.
    NOTE(review): ``bounds`` may carry its own terrain (e.g.
    ``MoonBounds.surface_func``); the two should presumably agree.
    """
    agent: Agent
    bounds: BoundsCheck
    # Any RewardAssignment, for example a ``Reward`` instance.
    reward: RewardAssignment
    surface: SurfaceFunction = flat_surface
    current_boundary_state: BoundaryStates = BoundaryStates.flying
    done_boundary_state: BoundaryStates = BoundaryStates.landed
    # Boundary states that also end the episode (in failure). Pass an empty
    # frozenset to terminate only on ``done_boundary_state``.
    failure_boundary_states: frozenset[BoundaryStates] = frozenset(
        {BoundaryStates.crashed, BoundaryStates.left_screen}
    )
    # Height above ``surface`` at or below which the legs count as touching.
    contact_tolerance: float = 0.01

    def step(self, action: Action) -> tuple[State, float, bool]:
        """Take one step in the environment with ``action``.

        Returns:
            ``(state, reward, done)``. ``done`` is true when the new boundary
            state is ``done_boundary_state`` or is one of
            ``failure_boundary_states``.
        """
        # This is a leaky abstraction; we are communicating over state.
        # We should instead communicate over actions and determine the
        # state based on the action reported by the agent and its previous
        # state!
        state = self.agent.take_action(action)
        self._update_leg_contact(state)

        # Classify once and reuse the result for both reward and termination.
        boundary_state = self.bounds(state)
        self.current_boundary_state = boundary_state

        reward = self.reward(state, boundary_state)
        done = (
            boundary_state is self.done_boundary_state
            or boundary_state in self.failure_boundary_states
        )
        return state, reward, done

    def _update_leg_contact(self, state: State) -> None:
        """Recompute both leg-contact flags from the height above ``surface``.

        The flags are reassigned on every step, so they clear again after
        the lander lifts off. Only the height is considered (not the angle),
        so both legs always get the same value.
        """
        touching = (state.y - self.surface(state.x)) <= self.contact_tolerance
        state.left_leg_contact = touching
        state.right_leg_contact = touching
    



# NOTE(review): module-level string whose name mirrors the ``Agent`` class
# and the ``Environment.agent`` field. Nothing in this view reads this
# global; confirm whether it is still needed.
agent = "lander"
