# In [None]:
import numpy as np
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Grid-world configuration
# ---------------------------------------------------------------------------

# Grid dimensions.
BOARD_ROWS, BOARD_COLS = 5, 5

# Special cells, written as 0-based (row, col) tuples.
START = (1, 0)        # default starting cell
JUMP_STATE = (1, 3)   # cell rewarded with +5 by State.give_reward
WIN_STATE = (4, 4)    # terminal cell; reaching it ends an episode
OBSTACLES = [(2, 2), (2, 3), (2, 4), (3, 2)]  # cells that cannot be entered

# When True, every action moves the agent exactly as requested.
DETERMINISTIC = True

# Module-level string holding a single board-cell glyph (initially empty).
token = ''


class State:
    """Agent position on the grid together with the movement rules.

    Coordinates are 0-based ``(row, col)`` tuples.
    """

    # (row delta, col delta) applied by each named action.
    _MOVES = {
        "up": (-1, 0),
        "down": (1, 0),
        "left": (0, -1),
        "right": (0, 1),
        "jump": (2, 0),
    }

    # Glyph printed by show_board for each board cell value.
    _GLYPHS = {1: '*', -1: 'z', 0: '0'}

    def __init__(self, state=START):
        """Create a state located at ``state`` (defaults to ``START``)."""
        # Display board: 0 = free cell, -1 = obstacle, 1 = agent (set by show_board).
        self.board = np.zeros([BOARD_ROWS, BOARD_COLS])
        # Mark the cells that nxt_position refuses to enter, so the printed
        # board agrees with the movement rules.
        for cell in OBSTACLES:
            self.board[cell] = -1
        self.state = state
        self.isEnd = False
        # True -> actions always move the agent exactly as requested.
        self.determine = DETERMINISTIC

    def give_reward(self):
        """Return the reward for the current cell.

        +10 on ``WIN_STATE``, +5 on ``JUMP_STATE`` and -1 everywhere else.
        """
        if self.state == WIN_STATE:
            return +10
        if self.state == JUMP_STATE:
            return +5
        return -1

    def is_end_func(self):
        """Set ``isEnd`` to True when the current cell is ``WIN_STATE``."""
        if self.state == WIN_STATE:
            self.isEnd = True

    def nxt_position(self, action):
        """Return the cell reached by taking ``action`` from the current cell.

        action: up, down, left, right, jump
        -------------
        0 | 1 | 2| 3| 4|
        1 |
        2 |
        3 |
        4 |

        "jump" moves two rows down. Any unrecognised action is treated as
        "right", which keeps the behaviour callers already rely on.

        The move only succeeds when the target lies inside the grid and is
        not listed in ``OBSTACLES``; otherwise the current cell is returned
        unchanged. This method does not modify ``self.state``.

        Raises:
            NotImplementedError: if the state is not deterministic. No
                stochastic transition model is defined, and silently
                returning ``None`` would only crash later in the caller.
        """
        if not self.determine:
            raise NotImplementedError(
                "stochastic transitions are not implemented"
            )

        d_row, d_col = self._MOVES.get(action, self._MOVES["right"])
        nxt_state = (self.state[0] + d_row, self.state[1] + d_col)

        inside_grid = (
            0 <= nxt_state[0] <= BOARD_ROWS - 1
            and 0 <= nxt_state[1] <= BOARD_COLS - 1
        )
        if inside_grid and nxt_state not in OBSTACLES:
            return nxt_state

        # Illegal move: the agent stays where it is.
        return self.state

    def show_board(self):
        """Print the grid: '*' for the agent, 'z' for obstacles, '0' for free cells."""
        self.board[self.state] = 1
        for i in range(BOARD_ROWS):
            print('-----------------')
            # Board cells hold floats; convert before looking up the glyph.
            cells = (self._GLYPHS[int(self.board[i, j])] for j in range(BOARD_COLS))
            print('| ' + ''.join(glyph + ' | ' for glyph in cells))
        print('-----------------')


# Agent of player

class Agent:
    """Learner that walks the grid and keeps one value estimate per cell.

    After every finished episode the final reward is propagated backwards
    through the visited cells using
    ``V(s) <- V(s) + lr * (target - V(s))``.
    """

    def __init__(self):
        """Set up the action list, learning parameters and zeroed state values."""
        self.states = []  # cells entered during the current episode, in order
        self.actions = ["up", "down", "left", "right", "jump"]
        self.State = State()
        self.lr = 0.2        # step size of the value update
        self.exp_rate = 0.3  # probability of picking a random action

        # One value per grid cell, all starting at 0.
        self.state_values = {
            (i, j): 0 for i in range(BOARD_ROWS) for j in range(BOARD_COLS)
        }

    def choose_action(self):
        """Pick the next action with an explore/exploit rule.

        With probability ``exp_rate`` a uniformly random allowed action is
        returned. Otherwise the action whose target cell has the highest
        stored value is returned, considering only actions that actually
        move the agent; ties go to the action listed last. "jump" is only
        allowed while standing on ``JUMP_STATE``.

        Returns:
            The chosen action name, or "" if no action moves the agent.
        """
        current = self.State.state
        allowed = [a for a in self.actions if a != "jump" or current == JUMP_STATE]

        if np.random.uniform(0, 1) <= self.exp_rate:
            return np.random.choice(allowed)

        action = ""
        # Start below every possible value so that negative estimates can
        # still win; starting at 0 would leave `action` empty in that case.
        best_value = float("-inf")
        for candidate in allowed:
            target = self.State.nxt_position(candidate)
            if target == current:
                # Blocked by a wall or an obstacle: not a real move.
                continue
            value = self.state_values[target]
            if value >= best_value:
                action, best_value = candidate, value
        return action

    def take_action(self, action):
        """Return a new ``State`` located where ``action`` leads."""
        return State(state=self.State.nxt_position(action))

    def reset(self):
        """Forget the episode trace and put the agent back on the start cell."""
        self.states = []
        self.State = State()

    def play(self, rounds=100):
        """Run ``rounds`` episodes, updating ``state_values`` after each one.

        Every step is printed. An episode ends when the agent reaches the
        end state; the final reward is then pushed backwards through the
        cells visited in that episode.
        """
        finished = 0
        while finished < rounds:
            if self.State.isEnd:
                reward = self.State.give_reward()
                # Pin the terminal cell to its reward before propagating.
                self.state_values[self.State.state] = reward
                print("Game End Reward", reward)

                # Walk the trace backwards; each updated value becomes the
                # target for the cell visited before it.
                for s in reversed(self.states):
                    reward = self.state_values[s] + self.lr * (reward - self.state_values[s])
                    self.state_values[s] = round(reward, 4)
                self.reset()
                finished += 1
            else:
                action = self.choose_action()
                print("current position {} action {}".format(self.State.state, action))
                # Move once and record the cell that was actually reached.
                self.State = self.take_action(action)
                self.states.append(self.State.state)
                self.State.is_end_func()
                print("nxt state", self.State.state)
                print("------------------------------------")

    def show_values(self):
        """Print the state values as a text table, one grid row per line."""
        for i in range(BOARD_ROWS):
            print('----------------------------------------------')
            out = '| '
            for j in range(BOARD_COLS):
                out += str(self.state_values[(i, j)]).ljust(6) + ' | '
            print(out)
        print('----------------------------------------------')

    def plot_state_values(self):
        """Show the state values as a heatmap."""
        # Only one value is stored per cell, so it is plotted directly.
        q_values = np.zeros((BOARD_ROWS, BOARD_COLS))
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                q_values[i, j] = self.state_values[(i, j)]
        plt.imshow(q_values, cmap='viridis', origin='upper')
        plt.colorbar()
        plt.title('Maximum Q-value for each state')
        plt.show()

    def plot_state_values_2(self):
        """Show the state values as numbers written into a grid of cells."""
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                plt.text(j, i, str(self.state_values[(i, j)]), ha='center', va='center', fontsize=12)
        axes = plt.gca()
        # Minor ticks sit on the cell borders, so the minor grid draws the cells.
        axes.set_yticks(np.arange(-0.5, BOARD_ROWS, 1), minor=True)
        axes.set_xticks(np.arange(-0.5, BOARD_COLS, 1), minor=True)
        plt.grid(which='minor', color='black', linestyle='-', linewidth=1)
        axes.invert_yaxis()  # row 0 at the top, matching the printed table
        plt.title('Accumulated State Values')
        plt.show()


if __name__ == "__main__":
    # Train for 100 episodes, then report the learned values.
    ag = Agent()
    ag.play(100)
    # show_values prints the table itself and returns nothing, so it must
    # not be wrapped in print() (that would emit a stray "None").
    ag.show_values()
    ag.plot_state_values_2()
    ag.plot_state_values()
