A pack of reinforcement learning algorithms.

Introduction


rlpack is a reinforcement learning library built on TensorFlow. It decouples algorithms from environments, which makes the algorithms easy to call.

Usage


The following example shows how to use rlpack to run the PPO algorithm in a MuJoCo environment.

# -*- coding: utf-8 -*-


import argparse
import time
from collections import namedtuple

import gym
import numpy as np
import tensorflow as tf

from rlpack.algos import PPO
from rlpack.utils import mlp, mlp_gaussian_policy

parser = argparse.ArgumentParser(description="Run PPO on a MuJoCo environment.")
parser.add_argument("--env", type=str, default="Reacher-v2")
args = parser.parse_args()

Transition = namedtuple('Transition', ('state', 'action', 'reward', 'done', 'early_stop', 'next_state'))


class Memory(object):
    """Stores the transitions collected during one epoch."""

    def __init__(self):
        self.memory = []

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self):
        # Transpose the list of Transitions into a single Transition
        # whose fields are tuples of per-step values.
        return Transition(*zip(*self.memory))


def policy_fn(x, a):
    # Gaussian policy with two hidden layers of 64 tanh units.
    return mlp_gaussian_policy(x, a, hidden_sizes=[64, 64], activation=tf.tanh)


def value_fn(x):
    # State-value network; squeeze the output to shape (batch,).
    v = mlp(x, [64, 64, 1])
    return tf.squeeze(v, axis=1)


def run_main():
    env = gym.make(args.env)
    dim_obs = env.observation_space.shape[0]
    dim_act = env.action_space.shape[0]
    max_ep_len = 1000

    agent = PPO(dim_act=dim_act, dim_obs=dim_obs, policy_fn=policy_fn, value_fn=value_fn, save_path="./log/ppo")

    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0
    for epoch in range(50):
        memory, ep_ret_list, ep_len_list = Memory(), [], []
        # Collect 1000 environment steps of on-policy data per epoch.
        for t in range(1000):
            a = agent.get_action(o[np.newaxis, :])[0]
            nexto, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1

            # early_stop marks trajectories cut off by the epoch boundary
            # or the episode-length cap rather than by the environment.
            memory.push(o, a, r, int(d), int(ep_len == max_ep_len or t == 1000 - 1), nexto)

            o = nexto

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == 1000 - 1):
                if not terminal:
                    print('Warning: trajectory cut off by epoch at %d steps.' % ep_len)
                if terminal:
                    # Record the episode return and length when a terminal
                    # state or the maximum episode length is reached.
                    ep_ret_list.append(ep_ret)
                    ep_len_list.append(ep_len)
                o, ep_ret, ep_len = env.reset(), 0, 0

        print(f"{epoch}th epoch. average_return={np.mean(ep_ret_list)}, average_len={np.mean(ep_len_list)}")

        # Update the policy on the batch of transitions collected this epoch.
        batch = memory.sample()
        agent.update([np.array(x) for x in batch])

    elapsed_time = time.time() - start_time
    print("elapsed time:", elapsed_time)


if __name__ == "__main__":
    run_main()
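
To run the example above, save it as a script (for instance run_ppo.py; the filename is just an illustration) and launch it with the --env flag, which accepts any continuous-control environment id registered in gym:

    $ python run_ppo.py --env Reacher-v2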

Installation


  1. Install the dependencies

The required packages are listed in environment.yml. We recommend using Anaconda to set up the Python environment; the commands below create and activate it.

    $ git clone https://github.com/liber145/rlpack
    $ cd rlpack
    $ conda env create -f environment.yml
    $ conda activate py36

  2. Install rlpack

    $ python setup.py install

The steps above also install gym, a widely used suite of reinforcement learning environments. gym supports more complex environments as well, such as MuJoCo; see the gym documentation for details.
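
As a quick sanity check after installation, the following minimal sketch verifies that the packages import correctly (the environment id is just an example; MuJoCo tasks such as Reacher-v2 additionally require MuJoCo and mujoco-py):

import gym
import rlpack  # should import without error once setup.py has run

# CartPole-v1 ships with gym itself; MuJoCo tasks like Reacher-v2
# also need a working MuJoCo installation.
env = gym.make("CartPole-v1")
obs = env.reset()
print("observation shape:", env.observation_space.shape)
print("action space:", env.action_space)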

Algorithms


Algorithm   Paper                                                              Type        Continuous   Discrete
DQN         Playing Atari with Deep Reinforcement Learning                    off-policy               ✓
DoubleDQN   Deep Reinforcement Learning with Double Q-learning                off-policy               ✓
DuelDQN     Dueling Network Architectures for Deep Reinforcement Learning     off-policy               ✓
DistDQN     A Distributional Perspective on Reinforcement Learning            off-policy               ✓
PG          Introduction to Reinforcement Learning                            on-policy   ✓            ✓
A2C         Asynchronous Methods for Deep Reinforcement Learning              on-policy   ✓            ✓
TRPO        Trust Region Policy Optimization                                  on-policy   ✓            ✓
PPO         Proximal Policy Optimization Algorithms                           on-policy   ✓            ✓
TD3         Addressing Function Approximation Error in Actor-Critic Methods   off-policy  ✓
DDPG        Continuous control with deep reinforcement learning               off-policy  ✓
SAC         Soft Actor-Critic                                                 off-policy  ✓

Explanations of some of the algorithms are available in the documentation.
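
Assuming the other algorithms follow the same agent interface as PPO above (get_action and update), switching to a discrete-action method from the table is mostly a matter of swapping the constructor. The sketch below runs DQN on CartPole-v1; the DQN constructor arguments are assumptions modeled on the PPO example, not a confirmed signature, so check the documentation and the examples directory before relying on it.

# A hedged sketch of a discrete-action algorithm from the table.
# NOTE: the DQN constructor arguments are assumptions modeled on the
# PPO example above, not rlpack's confirmed signature.
import gym
import numpy as np

from rlpack.algos import DQN

env = gym.make("CartPole-v1")            # discrete-action environment
dim_obs = env.observation_space.shape[0]
n_act = env.action_space.n               # number of discrete actions

agent = DQN(dim_obs=dim_obs, dim_act=n_act, save_path="./log/dqn")  # assumed args

o = env.reset()
for t in range(1000):
    a = agent.get_action(o[np.newaxis, :])[0]   # same call pattern as PPO
    o, r, d, _ = env.step(a)
    if d:
        o = env.reset()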

Reference code


The implementation drew on other excellent open-source code; the references that helped the most are listed below:

Learning resources

