# Init

In [1]:
from random import random

import gym
from gym import wrappers

import numpy as np
from scipy.spatial import KDTree

# Env Settings

In [2]:
_OBS_DIM = 4
_ACT_DIM = 1
_BIGNUM  = 1e5

# Minimum Viable Components

In [3]:
K     = 3 # ---------------------- Number of neighbors to query
N     = 0 # ---------------------- Number of exemplars
F     = np.zeros( (N,_OBS_DIM,) ) # Approximating function
V     = np.array( [] ) # ------- Action values
A     = np.zeros( (N,_ACT_DIM,) ) # Actions (Output)
KDT   = None # ------------------- Spatial tree
eps   = 1.0 # -------------------- Exploration probability
decay = 0.99
rad   = 0.0007812 # ------------------ Minimum distance between exemplars (Overwrite radius)
vMar  = 0.10 # ------------------- Allowed margin on value estimates

In [4]:
def points_from_indices( pnts, ndcs ):
    """ Get the subset of `pnts` designated by `ndcs` """
    N = len( ndcs )
    P = np.zeros( (N,pnts.shape[1],) )
    for i, idx in enumerate( ndcs ):
        P[i,:] = pnts[idx,:]
    return P

In [5]:

def cartpole_reward( X ):
    """ Reward high speed at the bottom and low speed at the top """
    # 0. Set limits
    maxThetaDot =  10.0
    maxX        =   2.0
    # 1. Set weights
    thFactor    = 100.0
    thDotFactor =   8.0
    # 2. Unpack & Normalize state
    xN        = abs( X[0] ) # Fulcrum position
    thetaN    = X[2] # ------ Angle
    thetaDotN = abs( X[3] ) # Angular velocity
    # 3. Reward high speed at the bottom and low speed at the top
    R = thFactor*np.cos(thetaN) # - thDotFactor*np.cos(thetaN)*(thetaDotN)
    R = thFactor*np.exp(np.cos(thetaN) - 1.0) # - thDotFactor*np.cos(thetaN)*(thetaDotN)
    # 4. Adjust for wandering
    if xN > maxX:
        R -= xN
    return R
    

def add_particle( state, action, value, getTree = False ):
    """ Add a new particle to the value function """
    global N, F, V, A, KDT
    
    # print( f"V-Stack: {F.shape} + {state.shape}" )
    
    if N < 1:
        F = state
        V = np.array( [value,] )
        A = action
    else:
        F = np.vstack( (F,state, ) )
        temp = V.tolist()
        temp.append( value )
        V = np.array( temp )
        A = np.vstack( (A,action,) )

    # print( f"New Values: {V}" )
    
    N += 1
    if getTree:
        return N, KDTree( F )
    else:
        return N, None
        

def recalc_spatial_tree():
    """ Recalculate spatial tree """
    global N, F, V, A, KDT
    if N > 1:
        KDT = KDTree( F )

    
def get_action_and_value_inv_dist( state ):
    """ Estimate the current optimal action and value for the state """
    global N, F, V, A, KDT
    if (N < 1) or (KDT is None):
        return None, None
    dists, indcs = KDT.query( state, K )
    dists = dists[0]
    indcs = indcs[0]
    # print( dists, indcs )
    fractV = []
    indcsV = []
    for i, d in enumerate( dists ):
        # print( d )
        if d < _BIGNUM:
            fractV.append( 1.0/d )
            indcsV.append( indcs[i] )
    fractV = np.array( fractV )
    normD  = np.linalg.norm( fractV )
    fractV = fractV / normD
    rtnAct = np.zeros( (_ACT_DIM,) )
    rtnVal = 0.0
    for i, frac in enumerate( fractV ):
        idx = indcsV[i]
        rtnAct += (A[ idx ] * frac)
        # print( V )
        rtnVal += (V[ idx ] * frac)
    return rtnAct, rtnVal

N_REF = 0
N_NEW = 0

def eval_particle( state, action, value ):
    """ Decide whetner this point represents a particle worth saving """
    global N, F, V, A, KDT, N_REF, N_NEW
    # 0. Get our estimate of the value of this state
    if N > 1:
        # print(N)
        estAct, estVal = get_action_and_value_inv_dist( state )
    else:
        estAct, estVal = None, None
    
    # 1. Find out if there is a particle there
    if KDT is not None:
        ndcs = KDT.query_ball_point( state, rad )
    else:
        ndcs = []

    # 2. If there is a particle already there and the current value is better, then update
    if len( ndcs ):
        N_REF += 1
        # print( f"[{value}, {estVal}]", end=', ' )
        if value > estVal:
            index = ndcs[0]
            # fNear = points_from_indices( F, ndcs )
            vNear = points_from_indices( V, ndcs )
            if vNear[0] < value:
                A[index,:] = action
                V[index]   = value
            if len( ndcs ) > 1:
                print( "WARNING: NEARNESS CONSTRAINT VIOLATED" )

    # 3. Elif this is an open space that does NOT estimate the value well
    elif (estVal is None) or abs(estVal - value) > abs(value * vMar):
        # print( state, action, value )
        N_NEW += 1
        add_particle( state, action, value )
        recalc_spatial_tree()
    # Else this is an open space that predicts the value well, No update!
    # N. Return the current number of particles in the estimator
    return N
    
            

# Simple Learning Test

In [6]:
EPISODES = 1000
epLen    =  500
avg_time = 0
max_time = -1
env      = gym.make( 'CartPole-v1' ).env
env      = wrappers.RecordEpisodeStatistics( env, 100 )
div      = 10

In [7]:
for i_episode in range( EPISODES ):
    # instansiating the environment
    obs = env.reset()[0].reshape( (1,_OBS_DIM,) )
    print( f"Episode {i_episode+1}: Starting at {obs}" )
    N_REF = 0
    N_NEW = 0
    for t in range( epLen ):
        # uncomment this is you want to see the rendering 
        #env.render()
        if (random() < eps) or (N<2):
            action = env.action_space.sample()
        else:
            action, preVal = get_action_and_value_inv_dist( obs )
            if action[0] >= 0.5:
                action = 1
            else:
                action = 0
            
        sLast = obs

        # print( action, end=', ' )
        obs, reward, terminated, truncated, info = env.step( action )
        # print( obs, obs.shape )
        
        reward = cartpole_reward( obs )
        
        obs    = obs.reshape( (1,_OBS_DIM,) )
        eval_particle( sLast, action, reward )
        
        if terminated:
            avg_time = avg_time + t
            if t > max_time:
                max_time = t
                # print( f"\tMax. Uptime: {max_time}" )
            #print("Episode finished after {} timesteps".format(t+1))
            break
    # resetting the enviroment
    env.reset()
    eps *= decay
    if (i_episode>1) and ((i_episode%div) == 0):
        print( f'\navg time agent survives : {avg_time/i_episode}, There are {N} particles, {N_REF} references and {N_NEW} additions\n' )
        

# printing the avg time the game lasted
avg_time = avg_time/EPISODES
print( '\navg time agent survives :', avg_time, '\n' )

Episode 1: Starting at [[-0.00706559  0.02780364 -0.04282921  0.03066145]]
Episode 2: Starting at [[-0.0455995   0.0395777  -0.00440477 -0.02827299]]
Episode 3: Starting at [[-0.04234758 -0.01884432  0.01906625 -0.02803242]]
Episode 4: Starting at [[0.01967673 0.04920952 0.02439643 0.02464868]]
Episode 5: Starting at [[ 0.03608368 -0.01246495 -0.01486861 -0.04445079]]
Episode 6: Starting at [[-0.03708541  0.01309603 -0.04695722 -0.03617593]]
Episode 7: Starting at [[-0.0415691  -0.00403561 -0.01697205 -0.02003552]]
Episode 8: Starting at [[ 0.00349656  0.01627246 -0.02666806  0.03393168]]
Episode 9: Starting at [[-0.01079308 -0.03299043  0.00315026  0.01930564]]
Episode 10: Starting at [[ 0.04380033 -0.02685823  0.0017688  -0.03348565]]
Episode 11: Starting at [[ 0.04802436 -0.04413629  0.03143498  0.03008341]]

avg time agent survives : 22.3, There are 2 particles, 11 references and 0 additions

Episode 12: Starting at [[ 0.00371436  0.00140088 -0.02200142 -0.01117432]]
Episode 13: St