In [1]:
!pip3 install rl_util


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2[0m[39;49m -> [0m[32;49m22.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.9 -m pip install --upgrade pip[0m


In [2]:
import pandas as pd
import jax
import jax.numpy as jnp
import random
from rl_util.test import test_policy
from rl_util.value import QFunction
from rl_util.environment import MarkovEnv
from rl_util.policy import EpsSoftPolicy, EpsSoftPolicyFromQ, GreedyPolicyFromQ
from rl_util.generator import simple_circle
import numpy as np

# Short aliases for the string keys used in the per-step dictionaries that make
# up an episode trace (and related bookkeeping).
S, A, R = 'state', 'action', 'reward'
V, G, NS = 'value', 'return', 'next_step'

# On-policy N-step SARSA

In [139]:
def n_step_sarsa(n, alpha, phi, eps, env, iterations):
    """On-policy n-step SARSA control.

    Runs `iterations` episodes on `env`, acting with an EpsSoftPolicyFromQ
    policy built from the current action values, and moves each visited
    (state, action) value towards its n-step return.

    Args:
        n: Number of rewards accumulated before bootstrapping from Q.
        alpha: Step size of the value update.
        phi: Per-step discount factor.
        eps: Exploration parameter passed to EpsSoftPolicyFromQ.
        env: Environment exposing reset(), step(action) -> (next_state, reward,
            done), state_space() and action_space().
        iterations: Number of episodes to run.

    Returns:
        A tuple (GreedyPolicyFromQ built from the learned values, QFunction).
    """
    q = QFunction(env)
    policy = EpsSoftPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space(), eps=eps)
    for _ in range(iterations):
        state = env.reset()
        action = policy(state)
        T = float('inf')  # episode length; unknown until the terminal step is seen
        t = 0
        # trace[i] holds (R_i, S_i, A_i); R_0 is a placeholder that is never summed.
        trace = [{S: state, A: action, R: 0}]
        while True:
            if t < T:
                next_state, reward, done = env.step(action)
                if done:
                    T = t + 1
                    action = None
                else:
                    action = policy(next_state)
                    state = next_state
                trace.append({R: reward, A: action, S: state})
            tau = t - n + 1  # time step whose estimate is updated now
            if tau >= 0:
                # The n-step return sums R_{tau+1} .. R_{min(tau+n, T)} INCLUSIVE.
                # The upper bound needs "+ 1" because range() is exclusive;
                # without it the last reward (and, for n == 1, every reward)
                # is silently dropped.
                last = min(T, tau + n)
                g = sum(trace[i][R] * (phi ** (i - tau - 1)) for i in range(tau + 1, last + 1))
                if tau + n < T:
                    # Episode still running n steps ahead: bootstrap from Q.
                    g = g + (phi ** n) * q(trace[tau + n][S], trace[tau + n][A])
                s_tau, a_tau = trace[tau][S], trace[tau][A]
                old_value = q(s_tau, a_tau)
                q.update(s_tau, a_tau, old_value + alpha * (g - old_value))
                # Rebuild the acting policy from the updated values.
                policy = EpsSoftPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space(), eps=eps)
            if tau + 1 == T:
                break
            t += 1
    return GreedyPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space()), q

# Off-policy N-step SARSA

In [138]:
def off_policy_n_step_sarsa(b_policy, n, alpha, phi, eps, env, iterations):
    """Off-policy n-step SARSA with importance sampling.

    Episodes are generated by `b_policy`; the target policy is an
    EpsSoftPolicyFromQ built from the current action values. Each update is
    weighted by the product of target/behaviour probability ratios.

    Args:
        b_policy: Behaviour policy; called as b_policy(state) to pick an action
            and queried as b_policy.p(action, state) for its probability.
        n: Number of rewards accumulated before bootstrapping from Q.
        alpha: Step size of the value update.
        phi: Per-step discount factor.
        eps: Exploration parameter passed to EpsSoftPolicyFromQ.
        env: Environment exposing reset(), step(action) -> (next_state, reward,
            done), state_space() and action_space().
        iterations: Number of episodes to run.

    Returns:
        A tuple (GreedyPolicyFromQ built from the learned values, QFunction).
    """
    q = QFunction(env)
    target_policy = EpsSoftPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space(), eps=eps)
    for _ in range(iterations):
        state = env.reset()
        action = b_policy(state)
        T = float('inf')  # episode length; unknown until the terminal step is seen
        t = 0
        # trace[i] holds (R_i, S_i, A_i); R_0 is a placeholder that is never summed.
        trace = [{R: 0, S: state, A: action}]
        while True:
            if t < T:
                next_state, reward, done = env.step(action)
                if done:
                    T = t + 1
                    action = None
                else:
                    action = b_policy(next_state)
                    state = next_state
                trace.append({R: reward, A: action, S: state})
            tau = t - n + 1  # time step whose estimate is updated now
            if tau >= 0:
                # Importance-sampling ratio over steps tau+1 .. min(tau+n-1, T-1)
                # INCLUSIVE ("+ 1" because range() is exclusive; the original
                # bound skipped the last factor).
                # NOTE(review): Sutton & Barto eq. 7.11 uses rho_{t+1:t+n}, i.e.
                # an upper limit of min(tau+n, T-1) — consider extending.
                ro = 1
                for i in range(tau + 1, min(T - 1, tau + n - 1) + 1):
                    ro *= target_policy.p(trace[i][A], trace[i][S]) / b_policy.p(trace[i][A], trace[i][S])
                # n-step return: R_{tau+1} .. R_{min(tau+n, T)} inclusive.
                last = min(T, tau + n)
                g = sum(trace[i][R] * (phi ** (i - tau - 1)) for i in range(tau + 1, last + 1))
                if tau + n < T:
                    # tau + n == t + 1 and t < T here, so trace[tau + n] was
                    # appended in this iteration and is safe to index.
                    g = g + (phi ** n) * q(trace[tau + n][S], trace[tau + n][A])
                s_tau, a_tau = trace[tau][S], trace[tau][A]
                old_value = q(s_tau, a_tau)
                q.update(s_tau, a_tau, old_value + alpha * ro * (g - old_value))
            # Rebuild the target policy from the (possibly) updated values.
            target_policy = EpsSoftPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space(), eps=eps)
            if tau + 1 == T:
                break
            t += 1
    return GreedyPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space()), q

# Tree Backup

In [184]:
def tree_backup(n, alpha, phi, eps, env, iterations):
    """n-step tree-backup control.

    The first action of each episode comes from the current
    EpsSoftPolicyFromQ policy; every later action is drawn uniformly with
    random.randint. Updates use the tree-backup return, which weights the
    untaken actions by their probability under the current policy instead of
    using importance sampling.

    Args:
        n: Number of steps backed up per update.
        alpha: Step size of the value update.
        phi: Per-step discount factor.
        eps: Exploration parameter passed to EpsSoftPolicyFromQ.
        env: Environment exposing reset(), step(action) -> (next_state, reward,
            done), state_space() and action_space(); action_space() is used as
            an integer action count.
        iterations: Number of episodes to run.

    Returns:
        A tuple (GreedyPolicyFromQ built from the learned values, QFunction).
    """
    q = QFunction(env)
    policy = EpsSoftPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space(), eps=eps)
    for _ in range(iterations):
        state = env.reset()
        action = policy(state)
        T = float('inf')  # episode length; unknown until the terminal step is seen
        t = 0
        # trace[i] holds (R_i, S_i, A_i); R_0 is a placeholder that is never read.
        trace = [{R: 0, S: state, A: action}]
        while True:
            if t < T:
                next_state, reward, done = env.step(action)
                if done:
                    T = t + 1
                    action = None
                else:
                    action = random.randint(0, env.action_space() - 1)
                    state = next_state
                trace.append({R: reward, A: action, S: state})
            tau = t - n + 1  # time step whose estimate is updated now
            if tau >= 0:
                if t + 1 >= T:
                    g = trace[T][R]
                else:
                    # Leaf of the tree: expected action value at S_{t+1}.
                    s_leaf = trace[t + 1][S]
                    g = trace[t + 1][R] + phi * sum(
                        policy.p(a, s_leaf) * q(s_leaf, a) for a in range(env.action_space())
                    )
                for k in range(min(t, T - 1), tau, -1):
                    s_k, a_k = trace[k][S], trace[k][A]
                    # Untaken actions contribute their current estimates ...
                    untaken = sum(
                        policy.p(a, s_k) * q(s_k, a)
                        for a in range(env.action_space())
                        if a != a_k
                    )
                    # ... while the taken action contributes the return from
                    # below, weighted by pi(A_k|S_k) only. (The original also
                    # multiplied by Q(S_k, A_k), which is not part of the
                    # tree-backup recursion.)
                    g = trace[k][R] + phi * (untaken + policy.p(a_k, s_k) * g)
                s_tau, a_tau = trace[tau][S], trace[tau][A]
                old_value = q(s_tau, a_tau)
                q.update(s_tau, a_tau, old_value + alpha * (g - old_value))
            # Rebuild the target policy from the (possibly) updated values.
            policy = EpsSoftPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space(), eps=eps)
            if tau + 1 == T:
                break
            t += 1
    return GreedyPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space()), q

# N-step Q(σ) (written Q(omega) in this notebook) with a per-step degree of sampling

In [203]:
def n_step_q_omega(b_policy, n, alpha, phi, eps, env, iterations):
    """Off-policy n-step Q(sigma); the degree of sampling is called omega here.

    Episodes are generated by `b_policy`. At every non-terminal step a degree
    of sampling omega is drawn uniformly from [0, 1] and the importance ratio
    of the chosen action is recorded; both are used to blend a sampled backup
    with an expected backup when the return is computed.

    Args:
        b_policy: Behaviour policy; called as b_policy(state) to pick an action
            and queried as b_policy.p(action, state) for its probability.
        n: Number of steps backed up per update.
        alpha: Step size of the value update.
        phi: Per-step discount factor.
        eps: Exploration parameter passed to EpsSoftPolicyFromQ.
        env: Environment exposing reset(), step(action) -> (next_state, reward,
            done), state_space() and action_space(); action_space() is used as
            an integer action count.
        iterations: Number of episodes to run.

    Returns:
        A tuple (GreedyPolicyFromQ built from the learned values, QFunction).
    """
    q = QFunction(env)
    target_policy = EpsSoftPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space(), eps=eps)
    for _ in range(iterations):
        state = env.reset()
        action = b_policy(state)
        T = float('inf')  # episode length; unknown until the terminal step is seen
        t = 0
        # Index-aligned with `trace`; entry 0 is a placeholder that is never read.
        omegas, ros = [0], [0]
        trace = [{R: 0, S: state, A: action}]
        while True:
            if t < T:
                next_state, reward, done = env.step(action)
                if done:
                    T = t + 1
                    action = None
                else:
                    action = b_policy(next_state)
                    state = next_state
                    # Degree of sampling and importance ratio for step t + 1.
                    omegas.append(random.uniform(0, 1))
                    ros.append(target_policy.p(action, state) / b_policy.p(action, state))
                trace.append({R: reward, A: action, S: state})
            tau = t - n + 1  # time step whose estimate is updated now
            if tau >= 0:
                # Seed the recursion with G_{h:h}. For a non-terminal horizon
                # that is Q(S_{t+1}, A_{t+1}), so the first (G - Q) correction
                # term vanishes; starting from 0 would subtract the bootstrap
                # value instead. For a terminal horizon the k == T branch
                # below overwrites the seed.
                if t + 1 < T:
                    g = q(trace[t + 1][S], trace[t + 1][A])
                else:
                    g = 0
                for k in range(min(t + 1, T), tau, -1):
                    if k == T:
                        g = trace[k][R]
                    else:
                        s_k, a_k = trace[k][S], trace[k][A]
                        # Expected action value of S_k under the target policy.
                        v = sum(target_policy.p(a, s_k) * q(s_k, a) for a in range(env.action_space()))
                        # Blend of sampling (ratio) and expectation (pi) weights.
                        weight = ros[k] * omegas[k] + (1 - omegas[k]) * target_policy.p(a_k, s_k)
                        g = trace[k][R] + phi * weight * (g - q(s_k, a_k)) + phi * v
                s_tau, a_tau = trace[tau][S], trace[tau][A]
                old_value = q(s_tau, a_tau)
                q.update(s_tau, a_tau, old_value + alpha * (g - old_value))
            # Rebuild the target policy from the (possibly) updated values.
            target_policy = EpsSoftPolicyFromQ(q.q, eps=eps, state_space=env.state_space(), action_space=env.action_space())
            if tau + 1 == T:
                break
            t += 1
    return GreedyPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space()), q

# Test

In [189]:
# Seed the RNGs so a fresh Restart & Run All reproduces the same episodes.
# `random` is used directly by tree_backup and n_step_q_omega; seeding NumPy is
# a precaution in case rl_util samples from its global RNG — TODO confirm.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

# Environment and hyper-parameters shared by the test cells below.
env = simple_circle(state_space=4, action_space=2)
alpha = 0.1      # step size of the value updates
phi = 0.99       # discount factor
eps = 0.5        # exploration parameter of the eps-soft policies
iterations = 15  # number of training episodes
n = 5            # number of steps per n-step update

In [190]:
# Show the environment's transition table as the cell output.
env.transitions

Unnamed: 0,state,action,reward,next_state,probability
0,0.0,0.0,-1.0,1.0,1.0
1,0.0,1.0,-2.0,3.0,1.0
2,1.0,0.0,-3.0,2.0,1.0
3,1.0,1.0,-3.0,0.0,1.0
4,2.0,0.0,-3.0,3.0,1.0
5,2.0,1.0,-3.0,2.0,1.0


In [191]:
# On-policy n-step SARSA: train with the shared settings, then run test_policy
# on the returned greedy policy.
# NOTE: `q` is rebound by each of the following test cells.
nss_policy, q = n_step_sarsa(n, alpha, phi, eps, env, iterations)
test_policy(env, nss_policy)

Finished in 1 steps, reward: -2.0


[0, 3]

In [200]:
# Off-policy n-step SARSA for 8 episodes, using the policy returned by the
# on-policy run (nss_policy) as the behaviour policy; that cell must run first.
opnss_policy, q = off_policy_n_step_sarsa(nss_policy, n, alpha, phi, eps, env, 8)
test_policy(env, opnss_policy)

Finished in 1 steps, reward: -2.0


[0, 3]

In [201]:
# Tree backup: train for a single episode, then run test_policy on the result.
tb_policy, q = tree_backup(n, alpha, phi, eps, env, 1)
test_policy(env, tb_policy)

Finished in 1 steps, reward: -2.0


[0, 3]

In [202]:
# Q(omega): a single episode with a 10-step backup, using nss_policy (from the
# on-policy SARSA cell) as the behaviour policy.
# NOTE(review): the recorded UnboundLocalError looks stale — this cell ran as
# In [202], before the current n_step_q_omega definition (In [203]), which
# assigns `ro` before using it. Re-run after the definition cell to refresh.
qo_policy, q = n_step_q_omega(nss_policy, 10, alpha, phi, eps, env, 1)
test_policy(env, qo_policy)

UnboundLocalError: local variable 'ro' referenced before assignment