### 首先考虑离散的 State、Action 空间组成的Q函数

In [3]:
from collections import defaultdict
from typing import Callable, List, Tuple, Optional

import numpy as np
from tqdm import tqdm

# Aliases that give plain Python types their reinforcement-learning meaning.
Action = int  # index of a discrete action
State = int  # index of a discrete state
Reward = float  # scalar reward returned by the environment
# Per-action vector for one state; despite the name, implementations in this
# notebook may return raw action values rather than normalised probabilities.
ActionProbDistribution = List[float]

class AbstractQFunc:
    """Interface for an action-value (Q) function over a discrete action set.

    Every method here is abstract: the base implementations only raise
    NotImplementedError and must be overridden by a concrete subclass.
    """

    def get_value(self, state: State, action: Action) -> float:
        """Return the stored value Q(state, action)."""
        raise NotImplementedError()

    def get_action_distribute(self, state: State) -> ActionProbDistribution:
        """Return a per-action vector for ``state``.

        NOTE(review): typed as a probability distribution, but a concrete
        subclass may return raw per-action values instead.
        """
        raise NotImplementedError()

    def get_actions_count(self) -> int:
        """Return how many discrete actions exist."""
        raise NotImplementedError()

    def set_value(self, state: State, action: Action, value: float) -> None:
        """Overwrite Q(state, action) with ``value``."""
        raise NotImplementedError()

class DiscreteQFunc(AbstractQFunc):
    """Tabular Q-function for discrete states and actions.

    Each state maps to a float32 vector holding one value per action. Rows
    are created lazily, all zeros, the first time a state is accessed.
    ``state_nums`` is stored but not otherwise used by this class.
    """

    def __init__(self, state_nums: int, action_nums: int) -> None:
        self._state_nums = state_nums
        self._action_nums = action_nums

        def _fresh_row():
            # Unseen states start with every action value at zero.
            return np.zeros(action_nums, dtype=np.float32)

        self._q_table = defaultdict(_fresh_row)

    def get_value(self, state, action) -> float:
        """Return Q(state, action); an unseen state is first added as a zero row."""
        action_values = self._q_table[state]
        return action_values[action]

    def set_value(self, state: State, action: Action, value: float) -> None:
        """Store ``value`` as Q(state, action); the value is kept as float32."""
        action_values = self._q_table[state]
        action_values[action] = value

    def get_action_distribute(self, state: State) -> ActionProbDistribution:
        """Return the live row of raw Q-values for ``state``.

        The array is not copied, so mutating it mutates the table. The values
        are not normalised probabilities.
        """
        action_values = self._q_table[state]
        return action_values

    def get_actions_count(self) -> int:
        """Return the number of actions each row holds."""
        count = self._action_nums
        return count

### 我们定义策略函数 π(s) = P(· | s)，即给定状态 s 时各个动作的概率。策略函数实际返回 Action 空间上的一个分布；在离散的情况下，我们用一个数组表示这个分布。下面定义一组函数，用于将 Q 函数转换为对应的策略

In [6]:


# 策略函数
# todo: change the right type
ActionProbDistribution = List[float]
Strategy = Callable[[State], ActionProbDistribution]


def to_strategy(f: AbstractQFunc) -> Strategy:
    """Wrap ``f`` as a policy returning ``f.get_action_distribute(s)`` for each state.

    The returned callable queries ``f`` on every call, so later updates to
    ``f`` are reflected. NOTE: DiscreteQFunc returns its raw Q-value row here,
    which is not a normalised probability vector.

    Args:
        f: the Q-function to read from.

    Returns:
        A Strategy, i.e. a callable mapping a State to a per-action vector.
    """
    def _strategy(s: State) -> ActionProbDistribution:
        return f.get_action_distribute(s)

    # Return the closure itself; without this the function yields None.
    return _strategy

def to_strategy_epsilon_greedy(f: AbstractQFunc, epsilon: float) -> Strategy:
    """Build an epsilon-greedy policy from ``f``.

    Each call of the returned policy draws one uniform random number:

    * with probability ``1 - epsilon`` (exploit) it returns a one-hot vector
      on the action with the largest entry of ``f.get_action_distribute(s)``
      (ties go to the lowest index);
    * otherwise (explore) it returns the uniform distribution over all
      ``f.get_actions_count()`` actions.

    Either way the result is a valid probability vector of length
    ``f.get_actions_count()``.

    Args:
        f: the Q-function to act greedily with respect to.
        epsilon: exploration probability in [0, 1].

    Returns:
        A Strategy mapping a State to a probability vector over actions.
    """
    def _strategy(s: State) -> ActionProbDistribution:
        n_actions = f.get_actions_count()
        if np.random.uniform(0, 1) > epsilon:
            # Exploit: put all probability mass on the best-scoring action.
            # Raw scores such as Q-values can be negative and need not sum to
            # 1, so they must not be returned as a distribution.
            probs = np.zeros(n_actions, dtype=np.float64)
            probs[int(np.argmax(f.get_action_distribute(s)))] = 1.0
            return probs
        # Explore: every action equally likely (a full length-n vector).
        return np.full(n_actions, 1.0 / n_actions)

    return _strategy


### 最后是训练流程：在一个环境中，首先根据当前状态进行决策，再执行动作并观察反馈（奖励与下一个状态），最后根据反馈更新 Q 函数

In [5]:
class AbstractEnv:
    """Interface for an episodic environment with discrete states and actions."""

    def step(self, action: Action) -> Tuple[Reward, Optional[State]]:
        """Apply ``action`` and return ``(reward, next_state)``.

        A ``next_state`` of ``None`` marks a terminal state, i.e. the episode
        has ended.
        """
        raise NotImplementedError()

    def reset(self) -> State:
        """Start a new episode and return its initial state."""
        # Must raise, not return: returning the exception object would hand
        # callers a bogus "state" instead of failing loudly.
        raise NotImplementedError()
    

class AbstractTrain:
    """Interface for a training procedure; subclasses implement ``train``."""

    def train(self):
        """Run the training loop. Abstract: always raises NotImplementedError."""
        raise NotImplementedError()
    
    

In [6]:
# Q-learning trainer driven by an epsilon-greedy policy (discrete / tabular
# Q-learning only).
class QLearningTrain(AbstractTrain):
    """Tabular Q-learning with an epsilon-greedy behaviour policy.

    Every episode (one "epoch") resets the environment, then repeatedly:
    picks an action from the epsilon-greedy policy built from the current
    Q-function, steps the environment, and applies the Q-learning update

        Q(s, a) <- Q(s, a) + learning_rate * (target - Q(s, a))
        target   = r                                   if s' is terminal
                 = r + gamma * max_a' Q(s', a')        otherwise

    The episode ends when the environment returns ``next_state is None`` or,
    if ``max_steps`` is set, after that many steps.

    Args:
        gamma: discount factor.
        learning_rate: step size of the update.
        epoch: number of episodes to run.
        epsilon_list: exploration rate per episode; ``epsilon_list[i]`` is
            used for episode ``i``, so it needs at least ``epoch`` entries
            (otherwise IndexError is raised when it runs out).
        q_func: Q-function, updated in place via ``set_value``.
        env: environment to interact with.
        max_steps: optional cap on steps per episode (the horizon); ``None``
            runs each episode until a terminal state.
    """

    def __init__(self, gamma: float, learning_rate: float, epoch: int, epsilon_list: List[float],
                 q_func: AbstractQFunc, env: AbstractEnv, max_steps: Optional[int] = None):
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.epoch = epoch
        self.epsilon_list = epsilon_list
        self.max_steps = max_steps

        self.q_func = q_func
        self.env = env

        self.current_state = None

    def _choose_action(self, dist: ActionProbDistribution) -> Action:
        """Turn a policy's output for one state into a concrete action index.

        A proper probability vector (non-negative, summing to 1) is sampled
        from; anything else is treated as per-action scores (e.g. raw
        Q-values) and the highest-scoring action is taken. A scalar is
        applied to every action.
        """
        n_actions = self.q_func.get_actions_count()
        probs = np.broadcast_to(np.asarray(dist, dtype=np.float64), (n_actions,))
        total = probs.sum()
        if np.all(probs >= 0) and np.isclose(total, 1.0):
            # Renormalise to absorb rounding before handing to np.random.choice.
            return int(np.random.choice(n_actions, p=probs / total))
        return int(np.argmax(probs))

    def _max_q(self, state: State) -> float:
        """Return max_a Q(state, a), read through the Q-function's value accessor."""
        return max(
            float(self.q_func.get_value(state, a))
            for a in range(self.q_func.get_actions_count())
        )

    def train(self) -> None:
        for epoch in tqdm(range(self.epoch)):
            # Reset the environment before every episode.
            self.current_state = self.env.reset()
            # The policy depends only on this episode's epsilon; it reads the
            # Q-function on every call, so in-episode updates are still seen.
            strategy = to_strategy_epsilon_greedy(self.q_func, self.epsilon_list[epoch])
            steps = 0
            while self.max_steps is None or steps < self.max_steps:
                steps += 1
                # Decide: sample a concrete action index from the policy.
                action = self._choose_action(strategy(self.current_state))
                # Act and observe.
                reward, next_state = self.env.step(action)
                if next_state is None:
                    # Terminal state: nothing to bootstrap from.
                    q_target = reward
                else:
                    # Bootstrap from the best next action's *value* (max), not
                    # its index (argmax).
                    q_target = reward + self.gamma * self._max_q(next_state)

                # Q-learning update towards the target.
                current_value = self.q_func.get_value(self.current_state, action)
                self.q_func.set_value(
                    self.current_state, action,
                    current_value + self.learning_rate * (q_target - current_value),
                )

                if next_state is None:
                    break
                self.current_state = next_state
 