In [1]:
!pip3 install rl_util


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2[0m[39;49m -> [0m[32;49m22.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3.9 -m pip install --upgrade pip[0m


In [2]:
import pandas as pd
import jax
import jax.numpy as jnp
import random
from rl_util.test import test_policy
from rl_util.value import QFunction
from rl_util.environment import MarkovEnv
from rl_util.policy import EpsSoftPolicy, EpsSoftPolicyFromQ
from rl_util.generator import simple_circle
import numpy as np

# String keys shared across the notebook. S, A and V are used below as the
# column names of the Q-table DataFrames; R and G are not referenced in the
# cells visible here.
S = 'state'
A = 'action'
R = 'reward'
V = 'value'
G = 'return'

In [4]:
class EpsSoftPolicyFromQs(EpsSoftPolicy):
    """Epsilon-soft policy that is greedy w.r.t. the SUM of several Q-tables.

    Used by double Q-learning, where behaviour is derived from both tables.

    Args:
        qs: iterable of Q-tables. Each one is filtered like a pandas DataFrame
            with 'state', 'action' and 'value' columns.
        state_space: forwarded to ``EpsSoftPolicy.__init__``.
        action_space: number of discrete actions; actions are ``0..action_space-1``.
        eps: exploration rate of the epsilon-soft distribution.

    NOTE(review): assumes the base class stores ``action_space`` and ``eps`` as
    attributes, since the methods below read them from ``self`` - confirm.
    """

    def __init__(self, qs, state_space: int, action_space: int, eps: float):
        super().__init__(state_space, action_space, eps)
        self.qs = qs

    def update(self, s, a):
        # This policy is derived from the Q-tables; direct updates are unsupported.
        raise Exception(':(')

    def _greedy_action(self, s):
        """Return the action with the largest summed value in state ``s``.

        Ties are resolved in favour of the lowest action index (strict ``>``).
        """
        best_a = None
        best_v = float('-inf')
        for candidate in range(self.action_space):
            cur_v = sum(
                q.loc[(q[S] == s) & (q[A] == candidate)][V].values[0]
                for q in self.qs
            )
            if cur_v > best_v:
                best_v = cur_v
                best_a = candidate
        return best_a

    def p(self, a, s):
        """Return the probability of choosing action ``a`` in state ``s``.

        The greedy search uses its own loop variable, so the queried action
        ``a`` is no longer overwritten before the comparison below.
        """
        if a == self._greedy_action(s):
            return 1 - self.eps + self.eps / self.action_space
        return self.eps / self.action_space

    def __call__(self, s):
        """Sample an action for state ``s`` from the epsilon-soft distribution."""
        best_a = self._greedy_action(s)
        probs = [self.eps / self.action_space for _ in range(self.action_space)]
        probs[best_a] = 1 - self.eps + self.eps / self.action_space
        return random.choices(list(range(self.action_space)), probs, k=1)[0]

# SARSA (on-policy TD control)
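The name stands for State–Action–Reward–State–Action: each update uses the quintuple $(S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1})$, where the next action $A_{t+1}$ is chosen by the same policy that is being learned.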

In [6]:
def sarsa(alpha, phi, eps, env, iterations):
    """Run tabular SARSA and return the final policy together with its Q-function.

    Args:
        alpha: step size applied to the TD error.
        phi: discount factor multiplying the value of the next state-action pair.
        eps: exploration rate handed to the epsilon-soft policy.
        env: environment exposing ``reset()``, ``step(action)`` (returning
            ``(next_state, reward, done)``), ``state_space()`` and ``action_space()``.
        iterations: number of episodes to run.

    Returns:
        Tuple ``(policy, q)``.
    """
    def build_policy(table):
        # The policy is rebuilt from the current table after every update.
        return EpsSoftPolicyFromQ(table.q, state_space=env.state_space(), action_space=env.action_space(), eps=eps)

    q = QFunction(env)
    policy = build_policy(q)
    for _episode in range(iterations):
        state = env.reset()
        action = policy(state)
        finished = False
        while not finished:
            next_state, reward, finished = env.step(action)
            current = q(state, action)
            # On-policy target: bootstrap from the action actually selected
            # for the next state; a terminal transition bootstraps from 0.
            next_action = None if finished else policy(next_state)
            bootstrap = 0 if finished else q(next_state, next_action)

            td_error = reward + phi * bootstrap - current
            q.update(state, action, current + alpha * td_error)
            state, action = next_state, next_action
            policy = build_policy(q)
    return policy, q

# Q-learning (off-policy TD control)

In [7]:
def q_learning(alpha, phi, eps, env, iterations):
    """Run tabular Q-learning and return the final policy together with its Q-function.

    Args:
        alpha: step size applied to the TD error.
        phi: discount factor multiplying the maximal value of the next state.
        eps: exploration rate handed to the epsilon-soft behaviour policy.
        env: environment exposing ``reset()``, ``step(action)`` (returning
            ``(next_state, reward, done)``), ``state_space()`` and ``action_space()``.
        iterations: number of episodes to run.

    Returns:
        Tuple ``(policy, q)``.
    """
    def build_policy(table):
        # The behaviour policy is rebuilt from the current table after every update.
        return EpsSoftPolicyFromQ(table.q, state_space=env.state_space(), action_space=env.action_space(), eps=eps)

    q = QFunction(env)
    policy = build_policy(q)
    for _episode in range(iterations):
        state = env.reset()
        finished = False
        while not finished:
            action = policy(state)
            next_state, reward, finished = env.step(action)

            current = q(state, action)
            # Off-policy target: bootstrap from the best value of the next
            # state, whatever action the behaviour policy takes next.
            bootstrap = 0 if finished else q.get_max(next_state)

            td_error = reward + phi * bootstrap - current
            q.update(state, action, current + alpha * td_error)
            state = next_state
            policy = build_policy(q)
    return policy, q

# Expected SARSA

In [8]:
def expected_sarsa(alpha, phi, eps, env, iterations):
    """Run tabular Expected SARSA and return the final policy with its Q-function.

    The target bootstraps from the expected value of the NEXT state under the
    current epsilon-soft policy: ``sum_a pi(a | next_state) * Q(next_state, a)``.

    Args:
        alpha: step size applied to the TD error.
        phi: discount factor multiplying the expected value of the next state.
        eps: exploration rate handed to the epsilon-soft policy.
        env: environment exposing ``reset()``, ``step(action)`` (returning
            ``(next_state, reward, done)``), ``state_space()`` and ``action_space()``.
        iterations: number of episodes to run.

    Returns:
        Tuple ``(policy, q)``.
    """
    q = QFunction(env)
    policy = EpsSoftPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space(), eps=eps)
    for _ in range(iterations):
        state = env.reset()
        done = False
        while not done:
            action = policy(state)
            next_state, reward, done = env.step(action)

            q_val = q(state, action)

            if done:
                # Terminal transition: nothing to bootstrap from.
                q_val_next = 0
            else:
                q_next = q.q.loc[(q.q[S] == next_state)]
                # The action probabilities must be those of next_state, the
                # state whose values are being averaged, not of the state
                # that was just left.
                q_val_next = sum(
                    policy.p(next_action, next_state) * value
                    for next_action, value in zip(q_next[A], q_next[V])
                )

            q.update(state, action, q_val + alpha * (reward + phi * q_val_next - q_val))
            state = next_state
            # Rebuild the policy so it reflects the updated table.
            policy = EpsSoftPolicyFromQ(q.q, state_space=env.state_space(), action_space=env.action_space(), eps=eps)
    return policy, q

# Double Q-learning

In [9]:
def double_q_learning(alpha, phi, eps, env, iterations):
    """Run tabular Double Q-learning and return the final policy and both Q-functions.

    Two tables are kept. On every step one of them is picked at random and
    updated: the greedy next action is selected with the picked table and
    evaluated with the other one.

    Args:
        alpha: step size applied to the TD error.
        phi: discount factor multiplying the bootstrapped value.
        eps: exploration rate handed to the epsilon-soft behaviour policy.
        env: environment exposing ``reset()``, ``step(action)`` (returning
            ``(next_state, reward, done)``), ``state_space()`` and ``action_space()``.
        iterations: number of episodes to run.

    Returns:
        Tuple ``(policy, qs)`` where ``qs`` is the list of the two Q-functions.
    """
    qs = [QFunction(env), QFunction(env)]
    policy = EpsSoftPolicyFromQs([q.q for q in qs], state_space=env.state_space(), action_space=env.action_space(), eps=eps)
    for _ in range(iterations):
        state = env.reset()
        done = False
        while not done:
            action = policy(state)
            next_state, reward, done = env.step(action)

            # q1 is the table updated on this step, q2 the one used for evaluation.
            q_index = random.randint(0, 1)
            q1, q2 = qs[q_index], qs[1 - q_index]
            q_val = q1(state, action)

            if done:
                # Terminal transition: nothing to bootstrap from.
                q_val_next = 0
            else:
                next_rows = q1.q.loc[q1.q[S] == next_state]
                # idxmax() returns an index LABEL, so look it up with .loc
                # rather than positional .iloc. Reading only the action column
                # also keeps the action's own dtype instead of the float
                # produced by extracting a whole mixed-dtype row.
                next_action = next_rows.loc[next_rows[V].idxmax(), A]
                q_val_next = q2(next_state, next_action)

            q1.update(state, action, q_val + alpha * (reward + phi * q_val_next - q_val))
            state = next_state
            # Rebuild the behaviour policy so it reflects the updated tables.
            policy = EpsSoftPolicyFromQs([q.q for q in qs], state_space=env.state_space(), action_space=env.action_space(), eps=eps)
    return policy, qs

# Tests

In [10]:
# Shared experiment setup used by every algorithm below.
env = simple_circle(state_space=10, action_space=2)
alpha = 0.1       # step size applied to the TD error
phi = 0.99        # discount factor applied to the bootstrapped value
eps = 0.5         # exploration rate of the epsilon-soft policies
iterations = 100  # number of training episodes per algorithm

In [11]:
# Inspect the environment's transition table
# (state, action, reward, next_state, probability).
env.transitions

Unnamed: 0,state,action,reward,next_state,probability
0,0.0,0.0,-3.0,1.0,1.0
1,0.0,1.0,-3.0,1.0,1.0
2,1.0,0.0,-2.0,2.0,1.0
3,1.0,1.0,-3.0,1.0,1.0
4,2.0,0.0,-3.0,3.0,1.0
5,2.0,1.0,-3.0,7.0,1.0
6,3.0,0.0,-1.0,4.0,1.0
7,3.0,1.0,-1.0,3.0,1.0
8,4.0,0.0,-2.0,5.0,1.0
9,4.0,1.0,-1.0,9.0,1.0


In [14]:
# SARSA: train on the shared environment, then evaluate the learned policy.
sarsa_policy, sarsa_q = sarsa(alpha, phi, eps, env, iterations)
test_policy(env, sarsa_policy)

Finished in 5 steps, reward: -12.0


([0, 1, 2, 7, 8, 9], -12.0, 5)

In [15]:
# Q-table (state, action, value) learned by SARSA.
sarsa_q.q

Unnamed: 0,state,action,value
0,0,0,-11.051901
1,0,1,-10.601612
2,1,0,-7.564619
3,1,1,-10.733962
4,2,0,-5.640143
5,2,1,-6.585432
6,3,0,-2.468018
7,3,1,-3.335235
8,4,0,-3.119355
9,4,1,-0.99807


In [16]:
# Q-learning: train on the shared environment, then evaluate the learned policy.
q_policy, q_q = q_learning(alpha, phi, eps, env, iterations)
test_policy(env, q_policy)

Finished in 5 steps, reward: -12.0


([0, 1, 2, 7, 8, 9], -12.0, 5)

In [17]:
# Q-table (state, action, value) learned by Q-learning.
q_q.q

Unnamed: 0,state,action,value
0,0,0,-9.599068
1,0,1,-9.623416
2,1,0,-6.881129
3,1,1,-9.085958
4,2,0,-4.950152
5,2,1,-6.09477
6,3,0,-1.985599
7,3,1,-2.893946
8,4,0,-1.454781
9,4,1,-0.999363


In [18]:
# Expected SARSA: train on the shared environment, then evaluate the learned policy.
es_policy, es_q = expected_sarsa(alpha, phi, eps, env, iterations)
test_policy(env, es_policy)

Finished in 11 steps, reward: -22.0


([0, 1, 1, 1, 2, 3, 3, 4, 5, 6, 7, 8], -22.0, 11)

In [19]:
# Q-table (state, action, value) learned by Expected SARSA.
es_q.q

Unnamed: 0,state,action,value
0,0,0,-12.004017
1,0,1,-12.166895
2,1,0,-8.683246
3,1,1,-11.764461
4,2,0,-7.001794
5,2,1,-7.193707
6,3,0,-4.391615
7,3,1,-4.684123
8,4,0,-4.712814
9,4,1,-0.997991


In [20]:
# Double Q-learning: train on the shared environment, then evaluate the learned
# policy; `qs` holds the two Q-functions.
double_policy, qs = double_q_learning(alpha, phi, eps, env, iterations)
test_policy(env, double_policy)

Finished in 6 steps, reward: -15.0


([0, 1, 2, 7, 8, 8, 9], -15.0, 6)

In [21]:
# First of the two Q-tables learned by Double Q-learning.
qs[0].q

Unnamed: 0,state,action,value
0,0,0,-8.235419
1,0,1,-8.214445
2,1,0,-6.250262
3,1,1,-8.465102
4,2,0,-4.559531
5,2,1,-4.994167
6,3,0,-1.829302
7,3,1,-2.329924
8,4,0,-2.823172
9,4,1,-0.917952


In [22]:
# Second of the two Q-tables learned by Double Q-learning.
qs[1].q

Unnamed: 0,state,action,value
0,0,0,-8.317341
1,0,1,-8.305963
2,1,0,-6.763737
3,1,1,-7.835273
4,2,0,-4.542935
5,2,1,-4.248458
6,3,0,-1.791456
7,3,1,-1.91606
8,4,0,-2.339919
9,4,1,-0.954555
