# HW 06

Implement value iteration on a grid-world MDP.

In [None]:
import numpy as np
from typing import List
from typing import Tuple
import copy
import matplotlib.pyplot as plt
import numpy as np


### DEFINITIONS

class State:
    """A state of the grid world, given by its x and y coordinate.

    Coordinates are checked against the module-level ``GRID_DIM_X`` /
    ``GRID_DIM_Y`` on construction.
    """

    def __init__(self, x: int, y: int):
        assert 0 <= x < GRID_DIM_X, "x out of bounds"
        assert 0 <= y < GRID_DIM_Y, "y out of bounds"
        self.coord_x = x
        self.coord_y = y

    def __eq__(self, other):
        # Only another State can compare equal; anything else is simply unequal
        if isinstance(other, State):
            return self.coord_x == other.coord_x \
                and self.coord_y == other.coord_y
        return False

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None (unhashable). Hash the same
        # fields __eq__ compares so equal states hash equally and can be used in
        # sets / as dict keys.
        return hash((self.coord_x, self.coord_y))

    def __str__(self):
        return f"State(x={self.coord_x}, y={self.coord_y})"

    def __repr__(self):
        return self.__str__()
    
class Action:
    """An action, encoded as an integer code.

    Codes: 0 ... stay, 1 ... up, 2 ... right, 3 ... down, 4 ... left
    """

    def __init__(self, action: int):
        assert 0 <= action < 5, "action out of bounds"
        self.action = action

    def __eq__(self, other):
        # Only another Action can compare equal; anything else is simply unequal
        if isinstance(other, Action):
            return self.action == other.action
        return False

    def __hash__(self):
        # Keep hashing consistent with __eq__ (defining __eq__ alone makes the
        # class unhashable).
        return hash(self.action)

    def __repr__(self):
        return f"Action({self.action!r})"
    
class Policy:
    """A deterministic policy: one action code stored per grid cell."""

    def __init__(self):
        # action code per cell, accessed as _action[state.coord_x, state.coord_y];
        # integer dtype because action codes are integers (default 0 = stay)
        self._action = np.zeros((GRID_DIM_X, GRID_DIM_Y), dtype=int)

    def set_value(self, state: State, action: Action):
        """Store ``action`` as the policy's choice in ``state``."""
        self._action[state.coord_x, state.coord_y] = action.action

    def get_value(self, state: State) -> Action:
        """Return the action the policy chooses in ``state``."""
        # int() so callers get a plain Python int code, not a numpy scalar
        return Action(int(self._action[state.coord_x, state.coord_y]))

    def print_arrow_table(self):
        """print the policy as ASCII art, generated by ChatGPT"""
        print("Policy")

        matrix = self._action
        # Define the symbols for each value
        symbols = {
            0: "•",      # stay
            1: "↑",      # up
            2: "→",      # right
            3: "↓",      # down
            4: "←"       # left
        }

        # Determine the width of each cell for alignment
        cell_width = max(len(symbols[v]) for row in matrix for v in row) + 2

        # Box-drawing borders, one segment per column
        top_border = "┌" + "┬".join("─" * cell_width for _ in matrix[0]) + "┐"
        row_separator = "├" + "┼".join("─" * cell_width for _ in matrix[0]) + "┤"
        bottom_border = "└" + "┴".join("─" * cell_width for _ in matrix[0]) + "┘"

        print(top_border)

        for i, row in enumerate(matrix):
            # Print row with arrows
            print("│" + "│".join(f"{symbols[val]:^{cell_width}}" for val in row) + "│")

            # Print row separator, except after the last row
            if i < len(matrix) - 1:
                print(row_separator)

        print(bottom_border)
    
class ValueFunction:
    """Tabular value function: one float per grid cell."""

    def __init__(self):
        # accessed as _storage[state.coord_x, state.coord_y]
        self._storage = np.zeros((GRID_DIM_X, GRID_DIM_Y))

    def get_value(self, state: State) -> float:
        """Return the stored value of ``state``."""
        return self._storage[state.coord_x, state.coord_y]

    def set_value(self, state: State, value: float):
        """Overwrite the stored value of ``state``."""
        self._storage[state.coord_x, state.coord_y] = value

    def plot(self):
        """returns a figure and axes to plot the value function as a heat map"""
        fig, ax = plt.subplots()
        cax = ax.imshow(self._storage, cmap='viridis', interpolation='nearest')
        fig.colorbar(cax, ax=ax, label="Value")
        ax.set_title("Value Function Heatmap")
        return fig, ax

    def extract_policy(self,
                       set_states: List[State],
                       set_actions: List[Action]) -> Policy:
        """Return the policy that acts greedily with respect to this value function.

        For each state, the chosen action is the one ``step_vi`` reports as
        maximising the one-step lookahead value.
        """
        policy = Policy()
        # step_vi always writes the maximised value into a value function; this
        # scratch instance absorbs those writes so ``self`` stays unchanged.
        value_function_discard = ValueFunction()
        for state in set_states:
            # just greedily take the action that gives the highest value
            action, _ = step_vi(state_evaluate=state,
                                set_states=set_states,
                                set_actions=set_actions,
                                value_function_new=value_function_discard,
                                value_function_old=self)
            policy.set_value(state, action)

        return policy

    def extract_qfunction(self,
                          set_states: List[State],
                          set_actions: List[Action]) -> np.ndarray:
        """Extract a Q function as a matrix indexed ``[s.x, s.y, a]``.

        Entries for (state, action) pairs not covered by ``set_states`` x
        ``set_actions`` are NaN.
        """
        # np.full rather than np.ndarray(shape): the latter returns uninitialised
        # memory, so uncovered entries would contain arbitrary garbage.
        qmatrix = np.full((GRID_DIM_X, GRID_DIM_Y, 5), np.nan)
        for state in set_states:
            for action in set_actions:
                val = substep_vi(state_evaluate=state,
                                 action=action,
                                 set_states=set_states,
                                 value_function_old=self)
                qmatrix[int(state.coord_x), int(state.coord_y), int(action.action)] = val
        return qmatrix

def print_qfunction(qmatrix: np.ndarray):
    """Print a Q function stored as an ``[x, y, action]`` matrix as a text table.

    One row is printed per (x, y, action) triple, with the action shown as an
    arrow symbol. Layout prettied up with ChatGPT.
    """
    expected_shape = (GRID_DIM_X, GRID_DIM_Y, 5)
    assert qmatrix.shape == expected_shape, "Dude, check what you feed your functions!"

    # index -> symbol: stay, up, right, down, left
    arrow_for = dict(enumerate("•↑→↓←"))

    header = f"{'X':<5} {'Y':<5} {'Action':<10} {'Q-Value':<10}"
    print("Q Function Table")
    print(header)
    print("-" * 30)

    # x outermost, action innermost: same row order as nested loops
    triples = ((x, y, a)
               for x in range(GRID_DIM_X)
               for y in range(GRID_DIM_Y)
               for a in range(5))
    for x, y, a in triples:
        print(f"{x:<5} {y:<5} {arrow_for[a]:<10} {qmatrix[x, y, a]:<10.4f}")

def reward_funciton(state: State):
    """Immediate reward for being in ``state``: 10 at TARGET_STATE, -1 anywhere else."""
    return 10 if TARGET_STATE == state else -1

def transition_probability(state_src: State, state_dst: State, action: Action):
    """Probability of ending up in ``state_dst`` when applying ``action`` in ``state_src``.

    Model:
      * action 0 (stay) keeps the agent in place with probability 1;
      * any other action moves to the intended neighbour with probability
        TRANSITION_PROB_P and leaves the agent in place otherwise;
      * if the intended neighbour lies outside the grid, the move is blocked
        and the agent stays in place with probability 1.

    ``state_dst`` is assumed to be a valid grid state. For every state and
    action, the probabilities over all grid states sum to 1.
    """
    # (dx, dy) of each action: 1 up = x-1, 2 right = y+1, 3 down = x+1, 4 left = y-1
    moves = {0: (0, 0), 1: (-1, 0), 2: (0, 1), 3: (1, 0), 4: (0, -1)}

    if action.action == 0:
        return 1 if state_dst == state_src else 0

    dx, dy = moves[action.action]
    target_x = state_src.coord_x + dx
    target_y = state_src.coord_y + dy
    target_on_grid = 0 <= target_x < GRID_DIM_X and 0 <= target_y < GRID_DIM_Y

    if state_dst == state_src:
        # The move failed (1 - p). Or it points off the grid, so the agent
        # bounces back with certainty and no probability mass leaks away.
        return 1 - TRANSITION_PROB_P if target_on_grid else 1
    if target_on_grid \
            and state_dst.coord_x == target_x \
            and state_dst.coord_y == target_y:
        return TRANSITION_PROB_P
    # everything else is unreachable in one step
    return 0

def substep_vi(action: Action,
               state_evaluate: State,
               set_states: List[State],
               value_function_old: ValueFunction) -> float:
    """One-step lookahead value of taking ``action`` in ``state_evaluate``.

    Computes R(s) + DISCOUNT * sum over s' of P(s' | s, a) * V_old(s'), where
    R is ``reward_funciton`` and P is ``transition_probability``. Kept separate
    from ``step_vi`` so the Q function can reuse it.
    """
    expected_next_value = sum(
        transition_probability(state_src=state_evaluate,
                               state_dst=successor,
                               action=action)
        * value_function_old.get_value(successor)
        for successor in set_states
    )
    immediate = reward_funciton(state_evaluate)
    return immediate + DISCOUNT * expected_next_value

def step_vi(state_evaluate: State,
            set_states: List[State],
            set_actions: List[Action],
            value_function_old: ValueFunction,
            value_function_new: ValueFunction) -> Tuple[Action, float]:
    """Bellman maximisation step of VI for one state (deterministic policy).

    Args:
        state_evaluate: the state being updated.
        set_states: all states (successor candidates).
        set_actions: all actions to maximise over.
        value_function_old: value function read for the lookahead.
        value_function_new: value function that receives the maximum.

    Returns:
        ``(best_action, best_value)``. On ties the earliest action in
        ``set_actions`` wins. If no action scores above -inf, the action is
        ``None`` and the value is -inf.
    """
    best_value = -np.inf
    best_action = None
    for candidate in set_actions:
        candidate_value = substep_vi(action=candidate,
                                     state_evaluate=state_evaluate,
                                     set_states=set_states,
                                     value_function_old=value_function_old)
        # strict '>' keeps the first of several equally good actions
        if candidate_value > best_value:
            best_value, best_action = candidate_value, candidate
    value_function_new.set_value(state_evaluate, best_value)
    return best_action, best_value

def value_iteration(value_function_init: ValueFunction,  # initial value function
                    set_states: List[State],  # all possible states
                    set_actions: List[Action],  # all possible actions
                    epsilon: float = 0.42  # convergence threshold
                    ) -> ValueFunction:
    """Run value iteration and return a (new) value function.

    Each sweep applies ``step_vi`` to every state. It reads only the previous
    sweep's values and writes into a separate copy, so ``value_function_init``
    itself is never modified. The algorithm terminates once the infinity norm
    of the difference between two consecutive value functions is < epsilon.
    """
    step = 0  # zero-based index of the current sweep
    value_function_old = value_function_init
    value_function_new = copy.deepcopy(value_function_init)
    while True:
        norm = -np.inf
        for state in set_states:
            step_vi(state_evaluate=state,
                    set_states=set_states, set_actions=set_actions,
                    value_function_old=value_function_old,
                    value_function_new=value_function_new)
            # track max |V_new(s) - V_old(s)| over the sweep
            norm_state = np.abs(value_function_new.get_value(state)
                                - value_function_old.get_value(state))
            if norm_state > norm:
                norm = norm_state
        # the freshly computed values become the input of the next sweep
        value_function_old = copy.deepcopy(value_function_new)
        print(f"[{step}]\tInfinity Norm: {norm}")
        if norm < epsilon:
            break
        step += 1
    # step is zero-based, so step + 1 sweeps have been performed
    print(f"Value Iteration converged after {step + 1} steps.")
    return value_function_old
            


### CONSTANTS

GRID_DIM_X = 4
GRID_DIM_Y = 4

TARGET_STATE = State(x=3, y=3)

DISCOUNT = 0.95

TRANSITION_PROB_P = 0.7  # probability that the transition actually works out

### SCRIPT

from pathlib import Path

# every grid cell is a state; action codes 0..4 (stay, up, right, down, left)
set_states = [State(x=x, y=y) for x in range(GRID_DIM_X) for y in range(GRID_DIM_Y)]
set_actions = [Action(a) for a in range(5)]

value_function = ValueFunction()

# run VI
found_value_function = value_iteration(value_function_init=value_function,
                                       set_states=set_states,
                                       set_actions=set_actions,
                                       epsilon=0.00001337)

# plot value function. Save it before plt.show(): depending on the backend
# (e.g. notebook inline), the figure may be closed once it has been shown.
fig, ax = found_value_function.plot()
output_dir = Path("./out")
output_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create missing dirs
fig.savefig(output_dir / "figure_hw06.png", format='png', dpi=300)
plt.show()

# extract policy
optimal_policy = found_value_function.extract_policy(
    set_states=set_states, set_actions=set_actions)

# print policy
optimal_policy.print_arrow_table()

# qfunction
qfunction = found_value_function.extract_qfunction(
    set_actions=set_actions, set_states=set_states)
print_qfunction(qfunction)

print("DONE")