In [59]:
import numpy as np
import gymnasium as gym
import rlenvs
from functools import reduce
# from scipy.special import softmax
# import pandas as pd

from rlcd.model import *

# Binning parameters for each observed quantity, unpacked as keyword arguments
# by discretize():
#   size  - number of bins the quantity is mapped into
#   step  - integer spacing between consecutive bin thresholds
#   scale - factor turning those integer thresholds into observation values
#           (so neighbouring thresholds are step * scale apart)
cart_velocity_params = dict(step=2, size=20, scale=0.1)
pole_angle_params = dict(step=45, size=20, scale=0.0001)
angular_velocity_params = dict(step=192, size=4, scale=0.001)

In [67]:
def discretize(s):
    """Turn a 4-element observation into a tuple of three bin indices.

    The first element of ``s`` (cart position) is discarded; the remaining
    three (cart velocity, pole angle, angular velocity) are each binned using
    the module-level ``*_params`` dictionaries.

    Returns
    -------
    tuple of int
        ``(cart_velocity_bin, pole_angle_bin, angular_velocity_bin)``.
    """
    def bin_index(val, step, size, scale):
        # Thresholds are k * scale for k running over an integer range with
        # stride `step`; the bin is the position of the first threshold that
        # is strictly greater than the value.
        first = (-(size - 2)) // 2 * step
        stop = (size + 1) // 2 * step
        for position, k in enumerate(range(first, stop, step)):
            if val < k * scale:
                return position
        # Value is at or above every threshold: it belongs to the last bin.
        return size - 1

    _cart_position, cart_velocity, pole_angle, angular_velocity = s
    velocity_bin = bin_index(cart_velocity, **cart_velocity_params)
    angle_bin = bin_index(pole_angle, **pole_angle_params)
    angular_bin = bin_index(angular_velocity, **angular_velocity_params)
    return (velocity_bin, angle_bin, angular_bin)

def enumerate_state(factored_state, factor_sizes):
    """Return the flat integer id of a factored (one index per factor) state.

    The numbering is the C-order flattening of
    ``np.meshgrid(*[np.arange(f) for f in factor_sizes])`` with NumPy's
    default ``indexing='xy'``: with two or more factors the *second* factor is
    the slowest-varying axis, the first factor comes next, and the remaining
    factors follow in order. The id is computed arithmetically rather than by
    building the whole grid and scanning it, so the cost does not grow with
    the number of states.

    Parameters
    ----------
    factored_state : sequence of int
        One index per factor, each within ``range(factor_sizes[k])``. Any
        sequence type is accepted (tuple, list, ...).
    factor_sizes : sequence of int
        Number of values each factor can take.

    Returns
    -------
    int
        The flat state id.

    Raises
    ------
    IndexError
        If the number of indices does not match the number of factors, or an
        index is non-integral or outside its factor's range.
    """
    coords = tuple(factored_state)
    dims = tuple(factor_sizes)

    # Range is tested before integrality so NaN/inf are rejected without
    # ever reaching int().
    off_grid = len(coords) != len(dims) or any(
        not 0 <= v < n or v != int(v) for v, n in zip(coords, dims)
    )
    if off_grid:
        raise IndexError(
            f"state {coords} is not on the grid with factor sizes {dims}"
        )

    # meshgrid's default 'xy' indexing swaps the first two axes of the grid.
    if len(dims) > 1:
        coords = (coords[1], coords[0]) + coords[2:]
        dims = (dims[1], dims[0]) + dims[2:]

    # Row-major (C-order) flattening, evaluated Horner-style.
    flat = 0
    for v, n in zip(coords, dims):
        flat = flat * int(n) + int(v)
    return flat

def factor_state(enum_state, factor_sizes):
    """Recover the per-factor indices of a flat state id.

    This inverts the numbering given by the C-order flattening of
    ``np.meshgrid(*[np.arange(f) for f in factor_sizes])`` with NumPy's
    default ``indexing='xy'`` (the numbering used by ``enumerate_state``):
    with two or more factors the second factor is the slowest-varying axis,
    then the first factor, then the rest in order. The indices are obtained
    with integer division instead of materialising the whole grid.

    Parameters
    ----------
    enum_state : int
        Flat state id. Negative values count back from the last state, as
        with ordinary NumPy indexing.
    factor_sizes : sequence of int
        Number of values each factor can take.

    Returns
    -------
    tuple of int
        One index per factor, in the order of ``factor_sizes``.

    Raises
    ------
    IndexError
        If ``enum_state`` is not an integral value or lies outside the grid.
    """
    dims = tuple(int(n) for n in factor_sizes)
    # meshgrid's default 'xy' indexing swaps the first two axes of the grid.
    if len(dims) > 1:
        dims = (dims[1], dims[0]) + dims[2:]

    total = 1
    for n in dims:
        total *= n

    idx = int(enum_state)
    if idx != enum_state:
        raise IndexError(f"state id {enum_state!r} is not an integer")
    if idx < 0:
        idx += total  # negative ids wrap around like NumPy column indices
    if not 0 <= idx < total:
        raise IndexError(
            f"state id {enum_state} is out of bounds for {total} states"
        )

    # Peel off coordinates from the fastest-varying (last) axis backwards.
    coords = []
    for n in reversed(dims):
        idx, c = divmod(idx, n)
        coords.append(c)
    coords.reverse()

    # Undo the axis swap so the result follows the order of factor_sizes.
    if len(coords) > 1:
        coords[0], coords[1] = coords[1], coords[0]
    return tuple(coords)

# Round-trip check: a factored state -> flat id -> back to the factored state.
s = (4, 10, 2)
enum = enumerate_state(factored_state=s, factor_sizes=(20, 20, 4))
print(enum, factor_state(enum_state=enum, factor_sizes=(20, 20, 4)), sep="\n")

818
(4, 10, 2)


In [84]:
# Headless environment; pass render_mode="human" to gym.make to watch the episode.
env = gym.make("custom/CartPole-v1")

# Number of bins per factor, in the same order as the tuple built by discretize().
sizes = [
    params['size']
    for params in (cart_velocity_params, pole_angle_params, angular_velocity_params)
]

# One integer id per discretised state, arranged as an array of shape (1, *sizes).
# NOTE(review): this array is filled in C order, whereas enumerate_state() numbers
# states in meshgrid 'xy' order (first two factors swapped) - confirm RLCD treats
# the ids as opaque labels rather than relying on their position in this array.
# The second argument is presumably the set of available actions - TODO confirm.
model = RLCD(
    np.arange(reduce(lambda a, b: a * b, sizes)).reshape((1, *sizes)),
    np.array([0, 1]),
)
agent = Dyna(model, n=100, alpha=.9, gamma=.9, epsilon=.1)


In [596]:
n = 500  # upper bound on the number of environment steps in this episode

# Start a new episode, passing the pole parameters through reset().
# For a repeatable start, seed it instead:
#   observation, info = env.reset(seed=82, masspole=.45, length=1.0)
observation, info = env.reset(masspole=.45, length=1.0)

# hist_s[k] is the state id observed before hist_a[k] was taken, so hist_s
# always holds one more entry than hist_a.
hist_s = [enumerate_state(discretize(observation), sizes)]
hist_a = []

for _ in range(n):
    # agent.pi() is indexed by the latest state id to pick the action
    # (presumably it returns the agent's per-state policy - TODO confirm).
    action = agent.pi()[hist_s[-1]]
    observation, reward, terminated, truncated, info = env.step(action)

    hist_s.append(enumerate_state(discretize(observation), sizes))
    hist_a.append(int(action))

    # Stop at the end of the episode; the environment is not reset here.
    if terminated or truncated:
        break

env.close()

# zip stops at the shorter list, so each action is paired with the state id
# it was chosen in and the final state id is left out.
list(zip(hist_s, hist_a))

[(1558, 0),
 (1555, 0),
 (1551, 0),
 (1547, 0),
 (1543, 0),
 (1543, 0),
 (1539, 0),
 (1535, 0),
 (1531, 0),
 (1527, 0)]