In [324]:
import sympy
import itertools
import random
import numpy as np
sympy.init_printing()

In [325]:
# Discount factor applied to the next-state value in every Bellman backup.
gamma = 0.9
# Action labels understood by `dynamics`: 'u'/'d' change y, 'l'/'r' change x.
actions = list("udlr")
# All 25 grid cells as (x, y) tuples in x-major order: (0, 0), (0, 1), ..., (4, 4).
states = list(itertools.product(range(5), range(5)))

In [326]:
def bellman_optimality_eqs():
    """Build the Bellman optimality equations for the 5x5 gridworld.

    For every cell (x, y) this creates one sympy equation

        v_xy = Max over a in `actions` of [ r + gamma * v_x'y' ]

    where ``(x', y', r) = dynamics(x, y, a)``.

    Returns:
        (variables, eqs): the 25 symbols ``v00 .. v44`` in x-major order, and
        the list of ``sympy.Eq`` objects in the same cell order.
    """
    cells = [(x, y) for x in range(5) for y in range(5)]
    # Named `variables` rather than `vars` so the builtin is not shadowed.
    variables = sympy.symbols([f"v{x}{y}" for x, y in cells])
    symbol_at = dict(zip(cells, variables))

    def backup(x, y, action):
        # One call to the environment per action (the original called it 3x).
        next_x, next_y, reward = dynamics(x, y, action)
        return reward + gamma * symbol_at[(next_x, next_y)]

    eqs = [
        sympy.Eq(symbol_at[(x, y)], sympy.Max(*[backup(x, y, a) for a in actions]))
        for x, y in cells
    ]
    return variables, eqs

def dynamics(x, y, action) -> tuple[int, int, float]:
    """Deterministic transition function of the 5x5 gridworld.

    Args:
        x: current column, changed by 'l' (-1) and 'r' (+1); valid range 0..4.
        y: current row, changed by 'u' (-1) and 'd' (+1); valid range 0..4.
        action: one of 'u', 'd', 'l', 'r'.

    Returns:
        ``(next_x, next_y, reward)``:

        * from (1, 0), every action jumps to (1, 4) with reward 10;
        * from (3, 0), every action jumps to (3, 2) with reward 10;
        * otherwise a move that stays on the grid gives reward 0, and a move
          that would leave the grid keeps the agent in place with reward -1.

    Raises:
        ValueError: if `action` is not one of 'u', 'd', 'l', 'r'. The
            original fell through the ``match`` and silently returned None.
    """
    # Special cells: the chosen action is irrelevant here.
    # NOTE(review): this layout matches Sutton & Barto's gridworld (Example 3.5),
    # where B at (3, 0) pays +5 rather than +10 -- confirm +10 is intended.
    if (x, y) == (1, 0):
        return (1, 4, 10)
    if (x, y) == (3, 0):
        return (3, 2, 10)

    offsets = {'u': (0, -1), 'd': (0, 1), 'l': (-1, 0), 'r': (1, 0)}
    if action not in offsets:
        raise ValueError(f"unknown action {action!r}; expected one of {list(offsets)}")
    dx, dy = offsets[action]
    next_x, next_y = x + dx, y + dy
    if 0 <= next_x <= 4 and 0 <= next_y <= 4:
        return (next_x, next_y, 0)
    # Bumping into a wall: stay put and pay a penalty.
    return (x, y, -1)
    


In [327]:
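# Brute-force attempt: ask sympy.nsolve to solve the 25 Bellman optimality
# equations directly. It is slow and gives wrong results.
# Naively, each state's max over 4 actions must be resolved: 4^25 combinations.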


# vars, eqs = bellman_optimality_eqs()
# print(vars,eqs)
# solution = sympy.nsolve(eqs[:25], vars[:25], [5 for var in vars[:25]])


In [328]:
class Policy:
    def __init__(self):
        pass
    def policy(self, state) -> dict[str, float]:
        raise NotImplementedError
    def sample_action(self, state):
        #Not thread safe, for what it's worth
        policy = self.policy(state)
        return str(np.random.choice(list(policy.keys()),1,p=list(policy.values()))[0])
    
class UniformRandom(Policy):
    """Equiprobable policy: every action in `actions` is equally likely."""

    def policy(self, state):
        """Return ``{action: 1 / len(actions)}`` for all actions; `state` is ignored.

        The probability is derived from the action set, not hard-coded as 0.25.
        This keeps the result a valid distribution if `actions` ever changes.
        """
        probability = 1 / len(actions)
        return {action: probability for action in actions}
    
class Greedy(Policy):
    """Deterministic policy that is greedy with respect to a state-value table.

    Each action is scored by a one-step lookahead through `dynamics`:
    ``reward + discount * state_values[x'][y']``. This is the same backup used
    in `bellman_optimality_eqs`, so the -1 for bumping into a wall counts.
    All probability goes to the best-scoring action; ties go to the earliest
    action in `actions`.
    """

    def __init__(self, state_values, discount=None, verbose=False):
        """
        Args:
            state_values: value table indexed as ``state_values[x][y]``
                (presumably a 5x5 array or nested list -- confirm with caller).
            discount: discount factor for the lookahead; defaults to the
                module-level `gamma`.
            verbose: if True, print the per-action scores on every `policy`
                call. The original always printed them.
        """
        super().__init__()
        self.state_values = state_values
        self.discount = gamma if discount is None else discount
        self.verbose = verbose

    def policy(self, state):
        """Return a one-hot ``{action: probability}`` dict for `state`."""
        action_values = {}
        for action in actions:
            # One call per action (the original called dynamics twice).
            next_x, next_y, reward = dynamics(*state, action)
            action_values[action] = reward + self.discount * self.state_values[next_x][next_y]
        if self.verbose:
            print(action_values)
        best_action = max(action_values, key=action_values.get)
        p = {action: 0 for action in actions}
        p[best_action] = 1
        return p

{'u': 0.25, 'd': 0.25, 'l': 0.25, 'r': 0.25}
['u', 'l', 'l', 'd', 'd', 'r', 'd', 'l', 'u', 'u', 'd', 'l', 'r', 'r', 'r', 'l', 'r', 'd', 'r', 'l']


In [None]:
def evaluate_policy():
    """Placeholder for policy evaluation.

    TODO: not implemented yet. It currently does nothing and returns None.
    """
    return None