In [1]:
import os, sys
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
%matplotlib inline
import gym
from gym import spaces
from gym.utils import seeding

In [2]:
# Register an extra environment id backed by gym's NChainEnv, forwarding n=10
# to its constructor (presumably the chain length — confirm against NChainEnv)
# and setting max_episode_steps=100.
# NOTE(review): the environment built later in this notebook uses the id
# "NChain-v0", not the "NChain-v1" id registered here.
gym.envs.register(
    id="NChain-v1",
    entry_point="gym.envs.toy_text:NChainEnv",
    max_episode_steps=100,
    kwargs=dict(n=10),
)

In [84]:
class SRAgent:
    """Tabular agent built around a successor-representation (SR) matrix.

    State held by the agent:
        sr_mat : (n_state, n_state) float array, the SR estimate; row ``s`` is
                 updated by :meth:`update_SR_mat`.
        mean_r : (n_state,) float array of per-state reward weights, updated by
                 :meth:`td_update`.
        Q      : (n_state, n_action) float table read by :meth:`choose_action`
                 and written by :meth:`update_Q`.

    Args:
        n_state: number of discrete states.
        n_action: number of discrete actions.
        beta: stored on the instance; no method of this class reads it.
        epsilon: probability of taking a uniformly random action in
            ``"epsilon"`` mode.
    """

    def __init__(self, n_state, n_action, beta=2.0, epsilon=0.1):
        self.n_state = n_state
        self.n_action = n_action
        self.sr_mat = np.zeros((n_state, n_state), dtype=float)
        self.mean_r = np.zeros(n_state, dtype=float)
        self.Q = np.zeros((n_state, n_action), dtype=float)
        self.beta = beta
        self.epsilon = epsilon

    def choose_action(self, current_s, mode="epsilon"):
        """Pick an action for ``current_s``.

        With ``mode == "epsilon"`` the agent explores uniformly at random with
        probability ``self.epsilon``; in every other case (including any other
        ``mode`` value) it returns ``argmax`` of the Q row, which resolves ties
        to the lowest action index.
        """
        # The random draw only happens in "epsilon" mode (short-circuit `and`),
        # so greedy mode does not consume random numbers.
        explore = mode == "epsilon" and np.random.sample() < self.epsilon
        if explore:
            # Use the agent's own action count rather than a notebook-global
            # `env`, so the class works on its own and cannot drift from the
            # environment it was sized for.
            return np.random.choice(self.n_action)
        return np.argmax(self.Q[current_s])

    def update_Q(self, state, action):
        """Set ``Q[state, action]`` to the SR-based value ``sr_mat[state] . mean_r``.

        NOTE(review): the value written does not depend on ``action``; the
        argument only selects which column of the row is overwritten.
        """
        self.Q[state, action] = np.dot(self.sr_mat[state, :], self.mean_r)

    def td_update(self, state, next_state, reward, lr=0.01, gamma=1.0):
        """Move the reward weights along ``sr_mat[state]`` by the TD error.

        The error is ``reward + gamma * V(next_state) - V(state)`` with
        ``V(s) = sr_mat[s] . mean_r``.

        Args:
            state, next_state: indices of the transition's endpoints.
            reward: reward observed on the transition.
            lr: step size.
            gamma: discount applied to the successor value. The default of 1.0
                keeps the undiscounted error this method always used; pass the
                discount given to :meth:`update_SR_mat` for a discounted target.
        """
        value = np.dot(self.sr_mat[state, :], self.mean_r)
        next_value = np.dot(self.sr_mat[next_state, :], self.mean_r)
        td_error = reward + gamma * next_value - value
        self.mean_r = self.mean_r + lr * td_error * self.sr_mat[state, :]

    def update_SR_mat(self, state, next_state, lr=0.01, gamma=0.99):
        """Move SR row ``state`` towards ``onehot(state) + gamma * sr_mat[next_state]``.

        Args:
            state, next_state: indices of the transition's endpoints.
            lr: step size of the update.
            gamma: discount applied to the successor's SR row.
        """
        # `gamma * row` allocates a new array, so the one-hot term can be added
        # in place without building an n_state x n_state identity every step.
        target = gamma * self.sr_mat[next_state, :]
        target[state] += 1.0
        self.sr_mat[state, :] += lr * (target - self.sr_mat[state, :])

In [85]:
# Run configuration read by the later cells of this notebook.
n_episodes = 10       # number of episodes the training loop runs
b_policy = "epsilon"  # mode name that SRAgent.choose_action treats as epsilon-greedy

# Learning settings handed to the agent's update methods.
lr = 1e-2     # step size
gamma = 0.95  # discount factor

In [86]:
# Build the environment, then an agent whose tables are sized from the
# environment's discrete observation and action spaces.
# Alternative environment left disabled: gym.make("FrozenLake8x8-v0")
env = gym.make("NChain-v0")
agent = SRAgent(
    n_state=env.observation_space.n,
    n_action=env.action_space.n,
)

In [87]:
# Online training: every environment step is followed by one update of the
# reward weights, one update of the SR row, and one Q-table write.
for i in range(n_episodes):
    state = env.reset()
    done = False
    c_rewards = 0.0  # undiscounted sum of rewards collected in this episode
    while not done:
        # Honour the behaviour policy chosen in the configuration cell.
        action = agent.choose_action(state, mode=b_policy)
        next_state, reward, done, _ = env.step(action)
        agent.td_update(state, next_state, reward, lr=lr)
        agent.update_SR_mat(state, next_state, lr=lr, gamma=gamma)
        # NOTE(review): this writes Q[next_state, action], pairing the
        # successor state with the action taken in the previous state —
        # confirm this is the intended Q update.
        agent.update_Q(next_state, action)
        state = next_state
        c_rewards += reward
    print(f"episode {i}:", c_rewards)

episode 0: 2412.0
episode 1: 2104.0
episode 2: 1904.0
episode 3: 2170.0
episode 4: 1774.0
episode 5: 1938.0
episode 6: 2050.0
episode 7: 1996.0
episode 8: 1708.0
episode 9: 2038.0


In [88]:
# Dump the learned reward weights, SR matrix and Q-table, one after another.
print(agent.mean_r, agent.sr_mat, agent.Q, sep="\n")

[ 93.39207304 106.52640077 150.03506566 209.56984283 170.45421786]
[[7.96910355 2.6504179  1.0305965  0.51099141 0.84496204]
 [6.55524514 3.13834113 1.17709475 0.58985681 0.9851063 ]
 [5.65623523 1.80924258 1.63253659 0.86597424 1.49042151]
 [5.4894196  1.75336052 0.60265535 1.2581616  2.25146995]
 [5.9670853  1.92263476 0.68856911 0.31497967 3.42689528]]
[[1427.52879252 1456.47201133]
 [1422.3661546  1421.81070493]
 [1401.44808663 1398.9206568 ]
 [1448.17742397 1285.42885315]
 [1567.69033625 1603.59918233]]


In [43]:
# Scratch check of list repetition: `*` repeats references, so all six entries
# of the result are the same inner list object.
[[0.0]]*6

[[0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]