# In [None]:
from typing import List,Tuple,TypeAlias
from enum import Enum, auto
import random

# Parameters: This line is a must. The grader parser uses this line to locate the Parameters cell.
GROUP_ID = 29  # presumably the submitting group's identifier -- not read by the visible code
ALGORITHM = 'ValItr'  # ValItr | QLrng | SARSA. Note that “|” denotes a choice. Only one of the choices should be provided.
TRACK_NAME = 'tracks/U-track.txt'  # default file loaded by Track() when no filename is given
CRASH_POS = 'NRST' # NRST | STRT -- default crash policy of RaceTrackEnv.step() (see RaceTrackEnv.do_crash)


FAIL_RATE = 0.2  # default probability, in RaceTrackEnv.step(), that the requested acceleration is not applied

# region Definitions and Setup
# A grid cell addressed as (row, col): the row indexes Track.state, the column indexes within that row.
Square: TypeAlias = Tuple[int, int]
# A per-step (row, col) displacement; used for both velocity and acceleration.
Vector: TypeAlias = Tuple[int, int]

class SquareType(Enum):
    """Kinds of cell that can appear on a track grid.

    The comment next to each member gives the character that stands for it
    in a track file.
    """

    START = auto()   # 'S': a square the car can start from
    FINISH = auto()  # 'F': a finish square
    OPEN = auto()    # '.': open track
    WALL = auto()    # '#': wall

    def __str__(self) -> str:
        """Render the member as its bare name, e.g. ``WALL``."""
        return f"{self.name}"

# Track-file character -> square type.
CHAR_TO_TOK = {
    'S': SquareType.START,
    'F': SquareType.FINISH,
    '.': SquareType.OPEN,
    '#': SquareType.WALL,
}

# Inverse lookup: square type -> track-file character.
TOK_TO_CHAR = {square_type: char for char, square_type in CHAR_TO_TOK.items()}

# Value stored per square in Track.reward_values; walls have no value.
SQUARE_COST = {
    SquareType.START: 1,
    SquareType.OPEN: 1,
    SquareType.FINISH: 0,
    SquareType.WALL: None,
}
# endregion

# region Track and Environment Classes
class Track:
    """Grid of squares loaded from a track text file.

    Attributes:
        state: Parsed grid; ``state[row][col]`` is the SquareType of a cell.
        reward_values: Grid parallel to ``state`` holding the SQUARE_COST
            entry of each cell (``None`` for walls).
        start_squares: ``(row, col)`` of every start square, in file order.
        finish_squares: ``(row, col)`` of every finish square, in file order.
    """

    def __init__(self,filename=TRACK_NAME):
        """Load the track stored in the file *filename*."""
        self.state: List[List[SquareType]] = []
        self.reward_values: List[List[float | None]] = []
        self.start_squares: List[Square] = []
        self.finish_squares: List[Square] = []

        self.parse_track(filename)

    def __str__(self):
        """Render the grid back into track-file characters, one row per line."""
        return '\n'.join(
            ''.join(TOK_TO_CHAR[tok] for tok in row) for row in self.state
        )

    def parse_track(self,track):
        """
        Read the file at path `track` and append its squares to this Track.

        The first line of the file is skipped (presumably a header line --
        TODO confirm against the track files). Every following line becomes
        one grid row; row 0 is the first line after the skipped one.

        Raises KeyError if a line contains a character not in CHAR_TO_TOK.
        """
        with open(track, 'r') as f:
            lines = f.readlines()

        for row, line in enumerate(lines[1:]):
            tok_line = []
            rew_line = []
            # Drop the line terminator so only grid characters are mapped.
            for col, char in enumerate(line.rstrip('\n')):
                tok = CHAR_TO_TOK[char]
                if tok is SquareType.START:
                    self.start_squares.append((row, col))
                if tok is SquareType.FINISH:
                    self.finish_squares.append((row, col))

                tok_line.append(tok)
                rew_line.append(SQUARE_COST[tok])

            self.state.append(tok_line)
            self.reward_values.append(rew_line)

    def get_square(self,square: Square) -> SquareType:
        """
        Return the type of the cell at `square` (row, col).

        No bounds check is done here: an index past the grid raises
        IndexError and a negative index wraps around (plain list indexing).
        Use is_square_drivable() for coordinates that may be off the grid.
        """
        return self.state[square[0]][square[1]]

    def get_start_squares(self) -> List[Square]:
        """Return the list of start squares (the live list, not a copy)."""
        return self.start_squares

    def is_square_finish(self, square: Square) -> bool:
        """Return True if the cell at `square` is a finish square."""
        return self.get_square(square) is SquareType.FINISH

    def is_square_drivable(self,square: Square) -> bool:
        """
        Return True if `square` lies on the grid and is not a wall.

        The row is checked against the number of rows and the column against
        the length of that row, so non-square grids are handled correctly.
        """
        row, col = square
        if not 0 <= row < len(self.state):
            return False
        if not 0 <= col < len(self.state[row]):
            return False
        return self.get_square(square) is not SquareType.WALL

class RaceTrackEnv():
    """
    Simulates one car driving on a Track.

    Attributes:
    track: Track the car drives on.
    start_position: Square the current race started from; the 'STRT' crash
        policy returns the car here.
    position: Square the car currently occupies.
    velocity: Current (row, col) velocity; step() keeps each component in [-5,5].
    acceleration: Acceleration applied on the most recent step.
    """

    def __init__(self, track: None|Track = None,starting_square: Square = None):
        """
        Create a stationary car.
        track: Track to drive on; when None, the default Track() is loaded.
        starting_square: Square to start from; when None, the track's first
            start square is used (IndexError if the track has none).
        """
        self.track:        Track = track if track is not None else Track()
        if starting_square is None:
            starting_square = self.track.start_squares[0]
        self.start_position: Square = starting_square
        self.position:     Square = starting_square
        self.velocity:     Vector = (0,0)
        self.acceleration: Vector = (0,0)

    def stop(self):
        """Zero both velocity and acceleration; the position is unchanged."""
        self.velocity = (0,0)
        self.acceleration = (0,0)

    def _place(self, position: Square):
        """Put the stationary car on `position` without changing the race start."""
        self.stop()
        self.position = position

    def reset(self, position: Square):
        """
        Start a race from `position`: the car is stopped, moved there, and
        `position` is remembered as the square the 'STRT' crash policy
        returns to.
        """
        self._place(position)
        self.start_position = position

    @staticmethod
    def bresenham_line(pos1: Square, pos2: Square) -> List[Square]:
        """
        Generate all points along a line using Bresenham's algorithm.
        Returns the grid squares from pos1 to pos2, both endpoints included,
        in travel order.
        """
        points = []

        x0, y0 = pos1
        x1, y1 = pos2
        dx, dy = abs(x1 - x0), abs(y1 - y0)
        sx = 1 if x0 < x1 else -1
        sy = 1 if y0 < y1 else -1
        # Running error term of the integer line walk.
        err = dx - dy

        x, y = x0, y0
        while True:
            points.append((x, y))
            if x == x1 and y == y1:
                break
            # Both tests use the error from before this iteration's updates.
            e2 = 2 * err
            if e2 > -dy:
                err -= dy
                x += sx
            if e2 < dx:
                err += dx
                y += sy

        return points

    def check_crash(self,target_square: Square) -> bool:
        """
        Check if moving along a line from current position to target crashes into an obstacle.
        Uses Bresenham's line algorithm to trace the path.
        Returns True if any square on the path (endpoints included) is a wall
        or lies off the track.
        NOTE(review): a move that passes over a finish square but ends in a
        wall is reported as a crash -- confirm this is intended.
        """
        path_points = self.bresenham_line(self.position, target_square)
        return any(not self.track.is_square_drivable(sq) for sq in path_points)

    def do_crash(self,crash_position: str):
        """
        Handles the crash based on the crash_position policy.
        crash_position: 'NRST' | 'STRT'
        1. 'NRST': Move to the start square nearest to the current position.
        2. 'STRT': Move to the starting square used at the beginning of the race
        In both cases the car is stopped (zero velocity and acceleration).
        Raises ValueError for any other policy.
        """
        if crash_position == 'NRST':
            row, col = self.position
            # Squared Euclidean distance is enough to rank candidates.
            nearest_start = min(
                self.track.start_squares,
                key=lambda sq: (sq[0]-row)**2 + (sq[1]-col)**2,
            )
            self._place(nearest_start)
        elif crash_position == 'STRT':
            # Return to where this race began, not merely the track's first
            # start square.
            self._place(self.start_position)
        else:
            raise ValueError(f"Invalid crash_position policy: {crash_position}")

    def check_failiure(self,fail_rate: float) -> bool:
        """
        Returns True if the action fails based on the fail_rate.
        Draws one random number per call.
        """
        return random.random() < fail_rate

    def check_finish(self) -> bool:
        """Returns True if the car currently sits on a finish square."""
        return self.track.is_square_finish(self.position)

    def step(self,acceleration: Vector,fail_rate=FAIL_RATE,crash_position=CRASH_POS):
        """
        Perform a step in the environment given an acceleration.
        `Note velocity values are capped to [-5,5]`
        acceleration: Tuple[int,int] where each value is in [-1,0,1]
        fail_rate: Probability of action failure.
        crash_position: 'NRST' | 'STRT' policy for handling crashes.
        Returns: None
        Raises ValueError if acceleration is not two values from [-1,0,1].
        """
        if len(acceleration) != 2 or not all(a in (-1,0,1) for a in acceleration):
            raise ValueError(f"Invalid acceleration: {acceleration}")

        # A failed action applies no acceleration this step (the velocity
        # carries over unchanged) rather than re-applying the previous
        # step's acceleration.
        if self.check_failiure(fail_rate):
            self.acceleration = (0,0)
        else:
            self.acceleration = acceleration

        # Caps velocity to [-5,5]
        self.velocity = tuple(
            min(5, max(-5, vel + acc))
            for vel, acc in zip(self.velocity, self.acceleration)
        )
        target_position = (self.position[0]+self.velocity[0],self.position[1]+self.velocity[1])

        if self.check_crash(target_position):
            self.do_crash(crash_position)
        else:
            self.position = target_position

# endregion

# region Model Based
class MDPModel():
    """
    Skeleton of an explicit MDP model holding transition and reward tables.

    The query methods are placeholders that currently do nothing and
    return None.
    """

    def __init__(self):
        # Stored on the instance (not as throwaway locals) so the query
        # methods can read them once implemented.
        self.transitions = {}
        self.rewards = {}

    def get_possible_actions(self,state):
        """Placeholder: not implemented yet; returns None."""
        pass

    def get_transition_states_and_probs(self,state,action):
        """Placeholder: not implemented yet; returns None."""
        pass

    def get_reward(self,state,action,next_state):
        """Placeholder: not implemented yet; returns None."""
        pass

class ValueIterationAgent():
    """
    Skeleton of a value-iteration agent.

    Holds a value table, a policy table, the model to plan over, and the
    gamma / theta settings. The planning methods are placeholders that
    currently do nothing and return None.
    """

    def __init__(self):
        # Stored on the instance (not as throwaway locals) so the planning
        # methods can read them once implemented.
        self.value_table = {}
        self.policy = {}
        self.model = None
        self.gamma = 0
        self.theta = 0

    def value_iteration(self):
        """Placeholder: not implemented yet; returns None."""
        pass

    def get_action_for(self,state):
        """Placeholder: not implemented yet; returns None."""
        pass
# endregion

# region Model Free
class SARSAAgent():
    """
    Skeleton of a SARSA agent.

    Holds a Q-table, the alpha / gamma / epsilon settings, and the last
    state and action seen. The learning methods are placeholders that
    currently do nothing and return None.
    """

    def __init__(self):
        # Stored on the instance (not as throwaway locals) so the learning
        # methods can read them once implemented.
        self.qtable = {}
        self.alpha = 0
        self.gamma = 0
        self.epsilon = 0
        self.last_state = None
        self.last_action = None

    def act(self,state):
        """Placeholder: not implemented yet; returns None."""
        pass

    def update(self,state,action,reward,next_state,next_action):
        """Placeholder: not implemented yet; returns None."""
        pass

    def best_action(self,state):
        """Placeholder: not implemented yet; returns None."""
        pass

class QLearningAgent():
    """
    Skeleton of a Q-learning agent.

    Holds a Q-table and the alpha / gamma / epsilon settings. The learning
    methods are placeholders that currently do nothing and return None.
    """

    def __init__(self):
        # Stored on the instance (not as throwaway locals) so the learning
        # methods can read them once implemented.
        self.qtable = {}
        self.alpha = 0
        self.gamma = 0
        self.epsilon = 0

    def act(self,state):
        """Placeholder: not implemented yet; returns None."""
        pass

    def update(self,state,action,reward,next_state):
        """Placeholder: not implemented yet; returns None."""
        pass

    def best_action(self,state):
        """Placeholder: not implemented yet; returns None."""
        pass
# endregion

# region Output and Metrics
class EpisodeRunner():
    """
    Skeleton of an episode runner.

    Holds the environment, the agent and a step limit. run_episode is a
    placeholder that currently does nothing and returns None.
    """

    def __init__(self):
        # Stored on the instance (not as throwaway locals) so run_episode
        # can read them once implemented.
        self.env = None
        self.agent = None
        self.max_steps = 0

    def run_episode(self,algorithm: SARSAAgent|QLearningAgent):
        """Placeholder: not implemented yet; returns None."""
        pass

class MetricsLogger():
    """
    Skeleton of a per-episode metrics log.

    Holds three lists (episodes, steps, rewards). The logging and reporting
    methods are placeholders that currently do nothing and return None.
    """

    def __init__(self):
        # Stored on the instance (not as throwaway locals) so the logging
        # methods can append to them once implemented.
        self.episodes: List[int] = []
        self.steps: List[int] = []
        self.rewards: List[float] = []

    def log_episode(self,episode:int,steps:int,reward:float):
        """Placeholder: not implemented yet; returns None."""
        pass

    def print_metrics(self):
        """Placeholder: not implemented yet; returns None."""
        pass

    def plot_metrics(self):
        """Placeholder: not implemented yet; returns None."""
        pass
# endregion

def main():
    """Script entry point: print the simulation banner.

    TODO: build the environment, agent and runner here, then run episodes
    and log their metrics.
    """
    print("Racecar MDP Simulation")


# Run the simulation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()








#########################
#SSSS###############FFFF#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
#....###############....#
##....#############....##
###....###########....###
###...................###
####.................####
#####...............#####
######.............######
#########################
