diff --git a/.github/workflows/algo_test.yml b/.github/workflows/algo_test.yml
index e61fe104aa..02f46cd9a1 100644
--- a/.github/workflows/algo_test.yml
+++ b/.github/workflows/algo_test.yml
@@ -31,5 +31,6 @@ jobs:
run: |
python -m pip install .
python -m pip install ".[test,k8s]"
+ python -m pip install transformers
./ding/scripts/install-k8s-tools.sh
make algotest
diff --git a/.github/workflows/envpool_test.yml b/.github/workflows/envpool_test.yml
index 66ea2f8d11..e14a4848a0 100644
--- a/.github/workflows/envpool_test.yml
+++ b/.github/workflows/envpool_test.yml
@@ -24,5 +24,6 @@ jobs:
python -m pip install .
python -m pip install ".[test,k8s]"
python -m pip install ".[envpool]"
+ python -m pip install transformers
./ding/scripts/install-k8s-tools.sh
make envpooltest
diff --git a/.github/workflows/platform_test.yml b/.github/workflows/platform_test.yml
index 2194ab027c..9857639b3b 100644
--- a/.github/workflows/platform_test.yml
+++ b/.github/workflows/platform_test.yml
@@ -25,5 +25,6 @@ jobs:
run: |
python -m pip install .
python -m pip install ".[test,k8s]"
+ python -m pip install transformers
python -m pip uninstall pytest-timeouts -y
make platformtest
diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml
index d58a863759..771dc596e3 100644
--- a/.github/workflows/unit_test.yml
+++ b/.github/workflows/unit_test.yml
@@ -25,6 +25,7 @@ jobs:
python -m pip install box2d-py
python -m pip install .
python -m pip install ".[test,k8s]"
+ python -m pip install transformers
./ding/scripts/install-k8s-tools.sh
make unittest
- name: Upload coverage to Codecov
@@ -53,5 +54,6 @@ jobs:
run: |
python -m pip install .
python -m pip install ".[test,k8s]"
+ python -m pip install transformers
./ding/scripts/install-k8s-tools.sh
make benchmark
diff --git a/README.md b/README.md
index a7f0e13c3b..7c26010214 100644
--- a/README.md
+++ b/README.md
@@ -86,7 +86,7 @@ It provides **python-first** and **asynchronous-native** task and middleware abs
- [awesome-diffusion-model-in-rl](https://github.com/opendilab/awesome-diffusion-model-in-rl): A curated list of Diffusion Model in RL resources
- [awesome-end-to-end-autonomous-driving](https://github.com/opendilab/awesome-end-to-end-autonomous-driving): A curated list of awesome End-to-End Autonomous Driving resources
- [awesome-driving-behavior-prediction](https://github.com/opendilab/awesome-driving-behavior-prediction): A collection of research papers for Driving Behavior Prediction
-
+
On the low-level end, DI-engine comes with a set of highly re-usable modules, including [RL optimization functions](https://github.com/opendilab/DI-engine/tree/main/ding/rl_utils), [PyTorch utilities](https://github.com/opendilab/DI-engine/tree/main/ding/torch_utils) and [auxiliary tools](https://github.com/opendilab/DI-engine/tree/main/ding/utils).
@@ -210,51 +210,52 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 7 | [SQL](https://arxiv.org/pdf/1702.08165.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [SQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sql.html)
[policy/sql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sql.py) | ding -m serial -c cartpole_sql_config.py -s 0 |
| 8 | [R2D2](https://openreview.net/forum?id=r1lyTjAqYX) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [R2D2 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d2.html)
[policy/r2d2](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d2.py) | ding -m serial -c cartpole_r2d2_config.py -s 0 |
| 9 | [PG](https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html)
[policy/pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pg.py) | ding -m serial -c cartpole_pg_config.py -s 0 |
-| 10 | [A2C](https://arxiv.org/pdf/1602.01783.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [A2C doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html)
[policy/a2c](https://github.com/opendilab/DI-engine/blob/main/ding/policy/a2c.py) | ding -m serial -c cartpole_a2c_config.py -s 0 |
-| 11 | [PPO](https://arxiv.org/abs/1707.06347)/[MAPPO](https://arxiv.org/pdf/2103.01955.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [PPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppo.html)
[policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | python3 -u cartpole_ppo_main.py / ding -m serial_onpolicy -c cartpole_ppo_config.py -s 0 |
-| 12 | [PPG](https://arxiv.org/pdf/2009.04416.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppg.html)
[policy/ppg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppg.py) | python3 -u cartpole_ppg_main.py |
-| 13 | [ACER](https://arxiv.org/pdf/1611.01224.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [ACER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/acer.html)
[policy/acer](https://github.com/opendilab/DI-engine/blob/main/ding/policy/acer.py) | ding -m serial -c cartpole_acer_config.py -s 0 |
-| 14 | [IMPALA](https://arxiv.org/abs/1802.01561) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [IMPALA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/impala.html)
[policy/impala](https://github.com/opendilab/DI-engine/blob/main/ding/policy/impala.py) | ding -m serial -c cartpole_impala_config.py -s 0 |
-| 15 | [DDPG](https://arxiv.org/pdf/1509.02971.pdf)/[PADDPG](https://arxiv.org/pdf/1511.04143.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [DDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)
[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c pendulum_ddpg_config.py -s 0 |
-| 16 | [TD3](https://arxiv.org/pdf/1802.09477.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [TD3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3.html)
[policy/td3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3.py) | python3 -u pendulum_td3_main.py / ding -m serial -c pendulum_td3_config.py -s 0 |
-| 17 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [D4PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/d4pg.html)
[policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
-| 18 | [SAC](https://arxiv.org/abs/1801.01290)/[MASAC] | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [SAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sac.html)
[policy/sac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sac.py) | ding -m serial -c pendulum_sac_config.py -s 0 |
-| 19 | [PDQN](https://arxiv.org/pdf/1810.06394.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_pdqn_config.py -s 0 |
-| 20 | [MPDQN](https://arxiv.org/pdf/1905.04388.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_mpdqn_config.py -s 0 |
-| 21 | [HPPO](https://arxiv.org/pdf/1903.01344.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | ding -m serial_onpolicy -c gym_hybrid_hppo_config.py -s 0 |
-| 22 | [BDQ](https://arxiv.org/pdf/1711.08946.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/bdq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u hopper_bdq_config.py |
-| 23 | [MDQN](https://arxiv.org/abs/2007.14430) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/mdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mdqn.py) | python3 -u asterix_mdqn_config.py |
-| 24 | [QMIX](https://arxiv.org/pdf/1803.11485.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [QMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qmix.html)
[policy/qmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qmix.py) | ding -m serial -c smac_3s5z_qmix_config.py -s 0 |
-| 25 | [COMA](https://arxiv.org/pdf/1705.08926.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [COMA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/coma.html)
[policy/coma](https://github.com/opendilab/DI-engine/blob/main/ding/policy/coma.py) | ding -m serial -c smac_3s5z_coma_config.py -s 0 |
-| 26 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 |
-| 27 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [WQMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/wqmix.html)
[policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 |
-| 28 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [CollaQ doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/collaq.html)
[policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
-| 29 | [MADDPG](https://arxiv.org/pdf/1706.02275.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [MADDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)
[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c ant_maddpg_config.py -s 0 |
-| 30 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [GAIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/gail.html)
[reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
-| 31 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [SQIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sqil.html)
[entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
-| 32 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [DQFD doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqfd.html)
[policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
-| 33 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [R2D3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d3.html)
[R2D3中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html)
[policy/r2d3](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html) | python3 -u pong_r2d3_r2d2expert_config.py |
-| 34 | [Guided Cost Learning](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [Guided Cost Learning中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/guided_cost_zh.html)
[reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
-| 35 | [TREX](https://arxiv.org/abs/1904.06387) | ![IL](https://img.shields.io/badge/-IL-purple) | [TREX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/trex.html)
[reward_model/trex](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/trex_reward_model.py) | python3 mujoco_trex_main.py |
-| 36 | [Implicit Behavorial Cloning](https://implicitbc.github.io/) (DFO+MCMC) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/ibc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ibc.py)
[model/template/ebm](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/ebm.py) | python3 d4rl_ibc_main.py -s 0 -c pen_human_ibc_mcmc_config.py |
-| 37 | [BCO](https://arxiv.org/pdf/1805.01954.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/bco](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_bco.py) | python3 -u cartpole_bco_config.py |
-| 38 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [HER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/her.html)
[reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
-| 39 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [RND doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/rnd.html)
[reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_rnd_onppo_config.py |
-| 40 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [ICM doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/icm.html)
[ICM中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/icm_zh.html)
[reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
-| 41 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [CQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/cql.html)
[policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
-| 42 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [TD3BC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3_bc.html)
[policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u d4rl_td3_bc_main.py |
-| 43 | [Decision Transformer](https://arxiv.org/pdf/2106.01345.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/dt](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dt.py) | python3 -u ding/example/dt.py |
-| 44 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)
[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
-| 45 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py |
-| 46 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py |
-| 47 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)
[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py |
-| 48 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py |
-| 49 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py |
-| 50 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
-| 51 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
-| 52 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
-| 53 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)
[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
-| 54 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
+| 10 | [PromptPG](https://arxiv.org/abs/2209.14610) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/prompt_pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_pg.py) | ding -m serial_onpolicy -c tabmwp_pg_config.py -s 0 |
+| 11 | [A2C](https://arxiv.org/pdf/1602.01783.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [A2C doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html)
[policy/a2c](https://github.com/opendilab/DI-engine/blob/main/ding/policy/a2c.py) | ding -m serial -c cartpole_a2c_config.py -s 0 |
+| 12 | [PPO](https://arxiv.org/abs/1707.06347)/[MAPPO](https://arxiv.org/pdf/2103.01955.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [PPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppo.html)
[policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | python3 -u cartpole_ppo_main.py / ding -m serial_onpolicy -c cartpole_ppo_config.py -s 0 |
+| 13 | [PPG](https://arxiv.org/pdf/2009.04416.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppg.html)
[policy/ppg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppg.py) | python3 -u cartpole_ppg_main.py |
+| 14 | [ACER](https://arxiv.org/pdf/1611.01224.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [ACER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/acer.html)
[policy/acer](https://github.com/opendilab/DI-engine/blob/main/ding/policy/acer.py) | ding -m serial -c cartpole_acer_config.py -s 0 |
+| 15 | [IMPALA](https://arxiv.org/abs/1802.01561) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [IMPALA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/impala.html)
[policy/impala](https://github.com/opendilab/DI-engine/blob/main/ding/policy/impala.py) | ding -m serial -c cartpole_impala_config.py -s 0 |
+| 16 | [DDPG](https://arxiv.org/pdf/1509.02971.pdf)/[PADDPG](https://arxiv.org/pdf/1511.04143.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [DDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)
[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c pendulum_ddpg_config.py -s 0 |
+| 17 | [TD3](https://arxiv.org/pdf/1802.09477.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [TD3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3.html)
[policy/td3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3.py) | python3 -u pendulum_td3_main.py / ding -m serial -c pendulum_td3_config.py -s 0 |
+| 18 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [D4PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/d4pg.html)
[policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
+| 19 | [SAC](https://arxiv.org/abs/1801.01290)/[MASAC] | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [SAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sac.html)
[policy/sac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sac.py) | ding -m serial -c pendulum_sac_config.py -s 0 |
+| 20 | [PDQN](https://arxiv.org/pdf/1810.06394.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_pdqn_config.py -s 0 |
+| 21 | [MPDQN](https://arxiv.org/pdf/1905.04388.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_mpdqn_config.py -s 0 |
+| 22 | [HPPO](https://arxiv.org/pdf/1903.01344.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | ding -m serial_onpolicy -c gym_hybrid_hppo_config.py -s 0 |
+| 23 | [BDQ](https://arxiv.org/pdf/1711.08946.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/bdq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u hopper_bdq_config.py |
+| 24 | [MDQN](https://arxiv.org/abs/2007.14430) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/mdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mdqn.py) | python3 -u asterix_mdqn_config.py |
+| 25 | [QMIX](https://arxiv.org/pdf/1803.11485.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [QMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qmix.html)
[policy/qmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qmix.py) | ding -m serial -c smac_3s5z_qmix_config.py -s 0 |
+| 26 | [COMA](https://arxiv.org/pdf/1705.08926.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [COMA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/coma.html)
[policy/coma](https://github.com/opendilab/DI-engine/blob/main/ding/policy/coma.py) | ding -m serial -c smac_3s5z_coma_config.py -s 0 |
+| 27 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 |
+| 28 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [WQMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/wqmix.html)
[policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 |
+| 29 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [CollaQ doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/collaq.html)
[policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
+| 30 | [MADDPG](https://arxiv.org/pdf/1706.02275.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [MADDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)
[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c ant_maddpg_config.py -s 0 |
+| 31 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [GAIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/gail.html)
[reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
+| 32 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [SQIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sqil.html)
[entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
+| 33 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [DQFD doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqfd.html)
[policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
+| 34 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [R2D3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d3.html)
[R2D3中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html)
[policy/r2d3](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html) | python3 -u pong_r2d3_r2d2expert_config.py |
+| 35 | [Guided Cost Learning](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [Guided Cost Learning中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/guided_cost_zh.html)
[reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
+| 36 | [TREX](https://arxiv.org/abs/1904.06387) | ![IL](https://img.shields.io/badge/-IL-purple) | [TREX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/trex.html)
[reward_model/trex](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/trex_reward_model.py) | python3 mujoco_trex_main.py |
+| 37 | [Implicit Behavioral Cloning](https://implicitbc.github.io/) (DFO+MCMC) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/ibc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ibc.py)
[model/template/ebm](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/ebm.py) | python3 d4rl_ibc_main.py -s 0 -c pen_human_ibc_mcmc_config.py |
+| 38 | [BCO](https://arxiv.org/pdf/1805.01954.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/bco](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_bco.py) | python3 -u cartpole_bco_config.py |
+| 39 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [HER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/her.html)
[reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
+| 40 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [RND doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/rnd.html)
[reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_rnd_onppo_config.py |
+| 41 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [ICM doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/icm.html)
[ICM中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/icm_zh.html)
[reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
+| 42 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [CQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/cql.html)
[policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
+| 43 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [TD3BC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3_bc.html)
[policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u d4rl_td3_bc_main.py |
+| 44 | [Decision Transformer](https://arxiv.org/pdf/2106.01345.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/dt](https://github.com/opendilab/DI-engine/blob/main/ding/policy/decision_transformer.py) | python3 -u d4rl_dt_main.py |
+| 45 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)
[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
+| 46 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py |
+| 47 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py |
+| 48 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)
[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py |
+| 49 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py |
+| 50 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py |
+| 51 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
+| 52 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
+| 53 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
+| 54 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)
[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
+| 55 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
@@ -299,7 +300,8 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 33 |[classic_control/acrobot](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/acrobot/acrobot.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/acrobot/envs)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/acrobot_zh.html) |
| 34 |[box2d/car_racing](https://github.com/openai/gym/blob/master/gym/envs/box2d/car_racing.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)
![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/box2d/carracing/car_racing.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/carracing/envs)
环境指南 |
| 35 |[metadrive](https://github.com/metadriverse/metadrive) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/metadrive/metadrive_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/metadrive/env)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/metadrive_zh.html) |
-| 36 |[cliffwalking](https://github.com/openai/gym/blob/master/gym/envs/toy_text/cliffwalking.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/cliffwalking/cliff_walking.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/cliffwalking/envs)
环境指南 |
+| 36 | tabmwp | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.PNG) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp/envs) |
+| 37 |[cliffwalking](https://github.com/openai/gym/blob/master/gym/envs/toy_text/cliffwalking.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/cliffwalking/cliff_walking.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/cliffwalking/envs)
环境指南 |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space
@@ -377,7 +379,7 @@ DI-engine utilizes [TreeTensor](https://github.com/opendilab/DI-treetensor) as t
```diff
import torch
import treetensor.torch as ttorch
-
+
B = 4
@@ -410,7 +412,7 @@ DI-engine utilizes [TreeTensor](https://github.com/opendilab/DI-treetensor) as t
- stacked_data = stack(data, dim=0)
+ data = [ttorch.tensor(d) for d in data]
+ stacked_data = ttorch.stack(data, dim=0)
-
+
# validate
- assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
- assert stacked_data['action'].shape == (B, 1)
diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py
index 59fdfaac82..a291936d64 100755
--- a/ding/model/template/__init__.py
+++ b/ding/model/template/__init__.py
@@ -5,6 +5,7 @@
from .vac import VAC, DREAMERVAC
from .bc import DiscreteBC, ContinuousBC
from .pg import PG
+from .language_transformer import LanguageTransformer
# algorithm-specific
from .ppg import PPG
from .qmix import Mixer, QMix
diff --git a/ding/model/template/language_transformer.py b/ding/model/template/language_transformer.py
new file mode 100644
index 0000000000..d9541bcf82
--- /dev/null
+++ b/ding/model/template/language_transformer.py
@@ -0,0 +1,63 @@
+import torch
+
+from ding.utils import MODEL_REGISTRY
+from torch import nn
+try:
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+except ImportError:
+ import sys
+ from ditk import logging
+ logging.warning("not found transformer, please install it using: pip install transformers")
+ sys.exit(1)
+
+
+@MODEL_REGISTRY.register('language_transformer')
+class LanguageTransformer(nn.Module):
+
+ def __init__(
+ self,
+ model_name: str = "bert-base-uncased",
+ add_linear: bool = False,
+ embedding_size: int = 128,
+ freeze_encoder: bool = True
+ ) -> None:
+ super().__init__()
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+ self.model = AutoModelForTokenClassification.from_pretrained(model_name)
+
+ # Freeze transformer encoder and only train the linear layer
+ if freeze_encoder:
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ if add_linear:
+ # Add an additional small, adjustable linear layer on top of BERT tuned through RL
+ self.embedding_size = embedding_size
+ self.linear = nn.Linear(
+ self.model.config.hidden_size, embedding_size
+ ) # 768 for bert-base-uncased, distilbert-base-uncased
+ else:
+ self.linear = None
+
+ def _calc_embedding(self, x: list) -> torch.Tensor:
+ # ``truncation=True`` means that if the length of a prompt exceeds the tokenizer's ``max_length``,
+ # the excess part is truncated. ``padding=True`` means that shorter prompts are padded to the length
+ # of the longest prompt in the batch. Together, these settings give all encoded prompts the same
+ # length, which enables batch-wise computation.
+ input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
+ output = self.model(**input, output_hidden_states=True)
+ # Get last layer hidden states
+ last_hidden_states = output.hidden_states[-1]
+ # Get [CLS] hidden states
+ sentence_embedding = last_hidden_states[:, 0, :] # len(input_list) x hidden_size
+
+ if self.linear:
+ sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size
+
+ return sentence_embedding
+
+ def forward(self, train_samples: list, candidate_samples: list) -> dict:
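+ # Score each (question, candidate prompt) pair with the dot product of their sentence embeddings.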
+ prompt_embedding = self._calc_embedding(train_samples)
+ cands_embedding = self._calc_embedding(candidate_samples)
+ scores = torch.mm(prompt_embedding, cands_embedding.t())
+ return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
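+
+
+# A minimal usage sketch (assumption: the ``bert-base-uncased`` weights are downloaded on first use):
+#   model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=128)
+#   out = model(["What is 2 + 2?"], ["candidate prompt A", "candidate prompt B"])
+#   out['logit'].shape == (1, 2)  # one score per (question, candidate) pair
+#   out['dist'].sample()          # samples one candidate index per question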
diff --git a/ding/model/template/tests/test_language_transformer.py b/ding/model/template/tests/test_language_transformer.py
new file mode 100644
index 0000000000..40095c2ab2
--- /dev/null
+++ b/ding/model/template/tests/test_language_transformer.py
@@ -0,0 +1,25 @@
+import pytest
+
+from ding.model.template.language_transformer import LanguageTransformer
+
+
+@pytest.mark.unittest
+class TestNLPPretrainedModel:
+
+ def check_model(self):
+ test_pids = [1]
+ cand_pids = [0, 2, 4]
+ problems = [
+ "This is problem 0", "This is the first question", "Second problem is here", "Another problem",
+ "This is the last problem"
+ ]
+ ctxt_list = [problems[pid] for pid in test_pids]
+ cands_list = [problems[pid] for pid in cand_pids]
+
+ model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
+ output = model(ctxt_list, cands_list)
+ assert output['logit'].shape == (1, 3)
+
+ model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
+ output = model(ctxt_list, cands_list)
+ assert output['logit'].shape == (1, 3)
diff --git a/ding/model/wrapper/model_wrappers.py b/ding/model/wrapper/model_wrappers.py
index 9da90bcaa9..38c6346f2e 100644
--- a/ding/model/wrapper/model_wrappers.py
+++ b/ding/model/wrapper/model_wrappers.py
@@ -411,6 +411,53 @@ def forward(self, *args, **kwargs):
return output
+class CombinationArgmaxSampleWrapper(IModelWrapper):
+ r"""
+ Overview:
+ Used to help the model to sample combination argmax action.
+ """
+
+ def forward(self, shot_number, *args, **kwargs):
+ output = self._model.forward(*args, **kwargs)
+ # Generate actions.
+ act = []
+ mask = torch.zeros_like(output['logit'])
+ for ii in range(shot_number):
+ masked_logit = output['logit'] + mask
+ actions = masked_logit.argmax(dim=-1)
+ act.append(actions)
+ for jj in range(actions.shape[0]):
+ mask[jj][actions[jj]] = -1e8
+ # `act` is shaped: (B, shot_number)
+ act = torch.stack(act, dim=1)
+ output['action'] = act
+ return output
+
+
+class CombinationMultinomialSampleWrapper(IModelWrapper):
+ r"""
+ Overview:
+ Used to help the model to sample combination multinomial action.
+ """
+
+ def forward(self, shot_number, *args, **kwargs):
+ output = self._model.forward(*args, **kwargs)
+ # Generate actions.
+ act = []
+ mask = torch.zeros_like(output['logit'])
+ for ii in range(shot_number):
+ dist = torch.distributions.Categorical(logits=output['logit'] + mask)
+ actions = dist.sample()
+ act.append(actions)
+ for jj in range(actions.shape[0]):
+ mask[jj][actions[jj]] = -1e8
+
+ # `act` is shaped: (B, shot_number)
+ act = torch.stack(act, dim=1)
+ output['action'] = act
+ return output
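+
+
+# A minimal sketch of what both combination wrappers produce (shapes are illustrative):
+# given output['logit'] of shape (B, N) and shot_number=k, ``output['action']`` is a (B, k)
+# tensor of distinct indices in [0, N), since each chosen index is masked with -1e8 before
+# the next argmax / multinomial draw.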
+
+
class HybridArgmaxSampleWrapper(IModelWrapper):
r"""
Overview:
@@ -906,6 +953,8 @@ def __init__(self, model, teacher_cfg):
# model wrapper
'target': TargetNetworkWrapper,
'teacher': TeacherNetworkWrapper,
+ 'combination_argmax_sample': CombinationArgmaxSampleWrapper,
+ 'combination_multinomial_sample': CombinationMultinomialSampleWrapper,
}
diff --git a/ding/model/wrapper/test_model_wrappers.py b/ding/model/wrapper/test_model_wrappers.py
index 274d2801bb..93334bee00 100644
--- a/ding/model/wrapper/test_model_wrappers.py
+++ b/ding/model/wrapper/test_model_wrappers.py
@@ -549,3 +549,17 @@ def test_transformer_memory_wrapper(self):
assert sum(new_memory2[:, -16:].flatten()) != 0
assert sum(new_memory2[:, :-16].flatten()) == 0
assert torch.all(torch.eq(new_memory1[:, -8:], new_memory2[:, -16:-8]))
+
+ def test_combination_argmax_sample_wrapper(self):
+ model = model_wrap(ActorMLP(), wrapper_name='combination_argmax_sample')
+ data = {'obs': torch.randn(4, 3)}
+ output = model.forward(shot_number=2, inputs=data)
+ assert output['action'].shape == (4, 2)
+ assert (output['action'] >= 0).all() and (output['action'] < 64).all()
+
+ def test_combination_multinomial_sample_wrapper(self):
+ model = model_wrap(ActorMLP(), wrapper_name='combination_multinomial_sample')
+ data = {'obs': torch.randn(4, 3)}
+ output = model.forward(shot_number=2, inputs=data)
+ assert output['action'].shape == (4, 2)
+ assert (output['action'] >= 0).all() and (output['action'] < 64).all()
diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py
index 9789de56f6..794a7f94fa 100755
--- a/ding/policy/__init__.py
+++ b/ding/policy/__init__.py
@@ -54,3 +54,4 @@
# new-type policy
from .ppof import PPOFPolicy
+from .prompt_pg import PromptPGPolicy
diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py
old mode 100755
new mode 100644
index 99929f56ef..3ecfeb204e
--- a/ding/policy/command_mode_policy_instance.py
+++ b/ding/policy/command_mode_policy_instance.py
@@ -50,6 +50,7 @@
from .bdq import BDQPolicy
from .bcq import BCQPolicy
from .edac import EDACPolicy
+from .prompt_pg import PromptPGPolicy
class EpsCommandModePolicy(CommandModePolicy):
@@ -438,3 +439,8 @@ def _get_setting_learn(self, command_info: dict) -> dict:
def _get_setting_eval(self, command_info: dict) -> dict:
return {}
+
+
+@POLICY_REGISTRY.register('prompt_pg_command')
+class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
+ pass
diff --git a/ding/policy/prompt_pg.py b/ding/policy/prompt_pg.py
new file mode 100644
index 0000000000..ebccadb8a3
--- /dev/null
+++ b/ding/policy/prompt_pg.py
@@ -0,0 +1,206 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import torch
+
+from ding.rl_utils import get_train_sample
+from ding.torch_utils import Adam, to_device
+from ding.utils import POLICY_REGISTRY, split_data_generator
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from ..model import model_wrap
+
+
+@POLICY_REGISTRY.register('prompt_pg')
+class PromptPGPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of Prompt Policy Gradient (PromptPG) algorithm.
+ Link of the original paper: https://arxiv.org/abs/2209.14610
+ """
+ config = dict(
+ # (string) RL policy register name (refer to function "register_policy").
+ type='prompt_pg',
+ # (bool) whether to use cuda for network.
+ cuda=True,
+ # (bool) whether to use the on-policy training pipeline (behaviour policy and training policy are the same)
+ on_policy=True, # PG is a strictly on-policy algorithm, so users should not modify this field
+ # (bool) whether to use deterministic action for evaluation.
+ deterministic_eval=True,
+ learn=dict(
+ # (int) the number of samples for one update.
+ batch_size=64,
+ # (float) the step size of one gradient descend.
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ # (float) max grad norm value.
+ grad_norm=5,
+ # (bool) whether to ignore done signal for non-termination env.
+ ignore_done=False,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ # n_episode=8,
+ # (int) trajectory unroll length
+ unroll_len=1,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) discount factor for future reward, in range [0, 1]
+ discount_factor=0,
+ collector=dict(get_train_sample=True),
+ ),
+ eval=dict(),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'language_transformer', ['ding.model.template.language_transformer']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the optimizer and algorithm config.
+ """
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._grad_norm = self._cfg.learn.grad_norm
+ self._learn_model = self._model # for compatibility
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'return']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
+ """
+ self._model.train()
+
+ return_infos = []
+ for i in range(0, len(data), self._cfg.learn.batch_size):
+ batch = default_collate(data[i:i + self._cfg.learn.batch_size])
+ if self._cuda:
+ batch = to_device(batch, self._device)
+
+ # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
+ train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"]
+ for ii in range(len(cand_samples)):
+ cand_samples[ii] = cand_samples[ii][0]
+ output = self._learn_model.forward(train_samples, cand_samples)
+ return_ = batch['return']
+
+ # Calculate the policy gradient loss.
+ real_act = batch['action'] # shape: (B, shot_number)
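+ # REINFORCE-style objective for each selected shot: -E[log pi(a_i | s) * R],
+ # plus an entropy bonus weighted by ``entropy_weight``.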
+ total_policy_loss, total_entropy_loss = 0, 0
+ for ii in range(self._cfg.shot_number):
+ log_prob = output['dist'].log_prob(real_act[:, ii])
+ policy_loss = -(log_prob * return_).mean()
+ total_policy_loss += policy_loss
+ total_entropy_loss += -self._entropy_weight * output['dist'].entropy().mean()
+ total_loss = total_entropy_loss + total_policy_loss
+
+ # update
+ self._optimizer.zero_grad()
+ total_loss.backward()
+
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ list(self._learn_model.parameters()),
+ max_norm=self._grad_norm,
+ )
+ self._optimizer.step()
+
+ # only record the last update's information in the logger
+ return_info = {
+ 'cur_lr': self._optimizer.param_groups[0]['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': total_policy_loss.item(),
+ 'entropy_loss': total_entropy_loss.item(),
+ 'return_abs_max': return_.abs().max().item(),
+ 'grad_norm': grad_norm,
+ }
+ return_infos.append(return_info)
+ return return_infos
+
+ def _init_collect(self) -> None:
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._gamma = self._cfg.collect.discount_factor
+ self._collect_model = model_wrap(self._model, wrapper_name='combination_multinomial_sample')
+
+ def _forward_collect(self, data: dict) -> dict:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ self._model.eval()
+ with torch.no_grad():
+ # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
+ for ii in range(len(data['candidate_samples'])):
+ data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
+ output = self._collect_model.forward(self._cfg.shot_number, data['train_sample'], data['candidate_samples'])
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ return {
+ 'obs': obs,
+ 'action': model_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+ Compute the discounted return for each transition in the trajectory, then generate training samples
+ Arguments:
+ - data (:obj:`list`): The trajectory's buffer list
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ if self._cfg.learn.ignore_done:
+ raise NotImplementedError
+
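+ # Monte-Carlo return computed backwards over the trajectory. With the default
+ # ``discount_factor=0``, each step's return reduces to its immediate reward.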
+ R = 0.
+ for i in reversed(range(len(data))):
+ R = self._gamma * R + data[i]['reward']
+ data[i]['return'] = R
+ return get_train_sample(data, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ self._eval_model = model_wrap(self._model, wrapper_name='combination_argmax_sample')
+
+ def _forward_eval(self, data: dict) -> dict:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ self._model.eval()
+ with torch.no_grad():
+ # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
+ for ii in range(len(data['candidate_samples'])):
+ data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
+ output = self._eval_model.forward(self._cfg.shot_number, data['train_sample'], data['candidate_samples'])
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return super()._monitor_vars_learn() + ['policy_loss', 'entropy_loss', 'return_abs_max', 'grad_norm']
diff --git a/dizoo/tabmwp/README.md b/dizoo/tabmwp/README.md
new file mode 100644
index 0000000000..410aed8e6f
--- /dev/null
+++ b/dizoo/tabmwp/README.md
@@ -0,0 +1,16 @@
+## TabMWP Env
+
+## Dataset
+
+The **TabMWP** dataset contains 38,431 tabular math word problems. Each question in **TabMWP** is aligned with a tabular context, which is presented as an image, semi-structured text, and a structured table. There are two types of questions: *free-text* and *multi-choice*, and each problem is annotated with gold solutions to reveal the multi-step reasoning process.
+
+The environment is described in the paper [Dynamic Prompt Learning via Policy Gradient for Semi-structured Mathematical Reasoning](https://arxiv.org/abs/2209.14610) by Pan Lu, Liang Qiu, Kai-Wei Chang, Ying Nian Wu, Song-Chun Zhu, Tanmay Rajpurohit, Peter Clark, Ashwin Kalyan, 2023.
+
+You can find more details in the [PromptPG](https://github.com/lupantech/PromptPG) repository.
+
+## Benchmark
+
+- We collect the responses of GPT-3 on a reduced dataset with 80 training samples and 16 candidate prompts. This way, users do not need to query GPT-3 with an OpenAI API key.
+- You can directly reproduce the benchmark by running ``python dizoo/tabmwp/config/tabmwp_pg_config.py``
+
+![origin](./benchmark.png)
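+
+To run against the live GPT-3 API instead of the pre-collected responses, a minimal sketch (assuming the fields defined in ``tabmwp_pg_config.py``; the key value is a placeholder):
+
+```python
+from dizoo.tabmwp.config.tabmwp_pg_config import main_config, create_config
+from ding.entry import serial_pipeline_onpolicy
+
+main_config.env.enable_replay = False  # query the OpenAI API instead of replaying cached responses
+main_config.env.api_key = '<your-openai-api-key>'  # placeholder, replace with your own key
+serial_pipeline_onpolicy((main_config, create_config), seed=0)
+```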
diff --git a/dizoo/tabmwp/__init__.py b/dizoo/tabmwp/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/dizoo/tabmwp/benchmark.png b/dizoo/tabmwp/benchmark.png
new file mode 100644
index 0000000000..9dfccc56c0
Binary files /dev/null and b/dizoo/tabmwp/benchmark.png differ
diff --git a/dizoo/tabmwp/config/tabmwp_pg_config.py b/dizoo/tabmwp/config/tabmwp_pg_config.py
new file mode 100644
index 0000000000..acda7bcdbd
--- /dev/null
+++ b/dizoo/tabmwp/config/tabmwp_pg_config.py
@@ -0,0 +1,66 @@
+from easydict import EasyDict
+
+tabmwp_prompt_pg_config = dict(
+ exp_name='tabmwp_prompt_pg_seed0',
+ env=dict(
+ collector_env_num=1,
+ evaluator_env_num=1,
+ n_evaluator_episode=1,
+ stop_value=1,
+ cand_number=16,
+ train_number=80,
+ engine='text-davinci-002',
+ temperature=0.,
+ max_tokens=512,
+ top_p=1.,
+ frequency_penalty=0.,
+ presence_penalty=0.,
+ option_inds=["A", "B", "C", "D", "E", "F"],
+ # The OpenAI API key. You can get a key at: https://platform.openai.com/
+ api_key='',
+ enable_replay=True,
+ prompt_format='TQ-A',
+ seed=0,
+ ),
+ policy=dict(
+ cuda=True,
+ shot_number=2,
+ model=dict(
+ model_name="bert-base-uncased",
+ add_linear=True,
+ freeze_encoder=True,
+ embedding_size=128,
+ ),
+ learn=dict(
+ batch_size=10,
+ # (float) the step size of one gradient descend.
+ learning_rate=0.001,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.001,
+ weight_decay=5e-3,
+ grad_norm=0.5,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, then train the model once
+ n_sample=20,
+ discount_factor=0.,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+)
+main_config = EasyDict(tabmwp_prompt_pg_config)
+
+tabmwp_prompt_pg_create_config = dict(
+ env=dict(
+ type='tabmwp',
+ import_names=['dizoo.tabmwp.envs.tabmwp_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='prompt_pg'),
+ replay_buffer=dict(type='naive'),
+)
+create_config = EasyDict(tabmwp_prompt_pg_create_config)
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/dizoo/tabmwp/envs/__init__.py b/dizoo/tabmwp/envs/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/dizoo/tabmwp/envs/tabmwp_env.py b/dizoo/tabmwp/envs/tabmwp_env.py
new file mode 100644
index 0000000000..4da1fdfe98
--- /dev/null
+++ b/dizoo/tabmwp/envs/tabmwp_env.py
@@ -0,0 +1,244 @@
+import os
+from functools import lru_cache
+
+import gym
+import openai
+import numpy as np
+
+from ding.utils import ENV_REGISTRY
+from ding.envs import BaseEnv, BaseEnvTimestep
+from dizoo.tabmwp.envs.utils import create_example_from_pid, build_prompt, get_gpt3_output, calc_rwkv, calc_internlm,\
+ extract_prediction, normalize_answer, load_data
+
+
+@ENV_REGISTRY.register('tabmwp')
+class TabMWP(BaseEnv):
+ model = None
+ tokenizer = None
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.enable_replay = cfg.enable_replay
+ self._init_flag = False
+ self.problems, self.cand_pids, self.train_pids = None, None, None
+ self.problem_id = 0
+ self.cand_examples = []
+ openai.api_key = cfg.api_key
+ self.observation_space = gym.spaces.Dict()
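+        # An action selects an ordered pair of distinct candidate prompts (2-shot), hence n * (n - 1) choices.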
+ self.action_space = gym.spaces.Discrete(self.cfg.cand_number * (self.cfg.cand_number - 1))
+ self.reward_space = gym.spaces.Box(
+ low=-1, high=1, shape=(1,), dtype=np.float32
+ )
+ self.correct_num = 0
+
+ # Initialize language model if needed.
+ assert self.cfg.engine in ['text-davinci-002', 'glm-10B', 'rwkv-7B', 'internlm-7B']
+
+ try:
+ if self.cfg.engine == 'glm-10B' and TabMWP.model is None:
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ TabMWP.tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-10b", trust_remote_code=True)
+ model = AutoModelForSeq2SeqLM.from_pretrained("THUDM/glm-10b", trust_remote_code=True)
+ TabMWP.model = model.half()
+ elif self.cfg.engine == 'rwkv-7B' and TabMWP.model is None:
+ from transformers import AutoTokenizer, RwkvForCausalLM
+ TabMWP.tokenizer = AutoTokenizer.from_pretrained("sgugger/rwkv-7b-pile", trust_remote_code=True)
+ model = RwkvForCausalLM.from_pretrained("sgugger/rwkv-7b-pile")
+ TabMWP.model = model.half()
+ elif self.cfg.engine == 'internlm-7B' and TabMWP.model is None:
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ TabMWP.tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-7b", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained("internlm/internlm-7b", trust_remote_code=True)
+ TabMWP.model = model.eval()
+ except ImportError:
+ import sys
+ from ditk import logging
+ logging.warning("not found transformer, please install it using: pip install transformers")
+ sys.exit(1)
+
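+    # GLM-specific generation: append a [MASK] token and decode the piece generated
+    # between <|startofpiece|> and <|endofpiece|>.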
+ @lru_cache(maxsize=10000)
+ def get_output(self, inp: str) -> str:
+ inputs = TabMWP.tokenizer(inp + " [MASK].", return_tensors="pt")
+ inputs = TabMWP.tokenizer.build_inputs_for_generation(inputs, max_gen_length=512)
+ inputs = {key: value.cuda() for key, value in inputs.items()}
+ outputs = TabMWP.model.generate(**inputs, max_length=512, eos_token_id=TabMWP.tokenizer.eop_token_id,
+ pad_token_id=TabMWP.tokenizer.eos_token_id)
+ outputs = TabMWP.tokenizer.decode(outputs[0].tolist())
+
+ t0 = outputs.find('<|startofpiece|>') + 16
+ t1 = outputs.find('<|endofpiece|>')
+
+ return outputs[t0:t1]
+
+ def seed(self, seed: int, dynamic_seed: bool = False) -> None:
+ self.cfg.seed = seed
+
+ def reset(self) -> dict:
+ self.problems, self.cand_pids, self.train_pids = load_data(self.cfg)
+ if TabMWP.model is not None:
+ TabMWP.model = TabMWP.model.cuda()
+ if self.enable_replay:
+ self.cand_pids = ['32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514',
+ '19270', '23713', '17209', '33379', '34987', '11177']
+ if self.cfg.seed == 0: # train
+ self.train_pids = ['14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433',
+ '26979', '18135', '13347', '17679', '38426', '3454', '10432', '31011', '12162',
+ '13063', '7812', '29661', '24482', '4970', '4405', '17405', '27781', '26724', '5993',
+ '16442', '30148', '15895', '6855', '29903', '18107', '29504', '11106', '32964',
+ '29891', '32104', '15712', '24287', '4997', '32581', '21020', '17247', '31455',
+ '13245', '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198',
+ '26039', '3791', '4909', '37056', '7144', '8185', '2131', '4398', '38199', '29520',
+ '37329', '21388', '28659', '15044', '28510', '12903', '11794', '37095', '32229',
+ '22918', '31680', '15024', '24607', '26930']
+ model_io_path = 'dizoo/tabmwp/data/model_in_out_train.txt'
+ if not os.path.exists(model_io_path):
+ os.system(f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_train.txt -O '
+ + model_io_path + ' --no-check-certificate')
+ else:
+ self.train_pids = ['21037', '22976', '2224', '14145', '27962', '26553', '22110', '16541', '26044',
+ '19492', '31882', '11991', '27594', '7637', '15394', '7666', '5177', '33761',
+ '13703', '29105']
+                model_io_path = 'dizoo/tabmwp/data/model_in_out_eval.txt'
+                if not os.path.exists(model_io_path):
+                    os.system(f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_eval.txt -O '
+                              + model_io_path + ' --no-check-certificate')
+
+ self.cfg.cand_number = len(self.cand_pids)
+ self.cfg.train_number = len(self.train_pids)
+
+ self.results_memory = []
+ with open(model_io_path, encoding="ISO-8859-1") as f:
+ tmp = f.read().split('\n')
+ for tt in tmp:
+ if len(tt.strip()) == 0:
+ continue
+ self.results_memory.append(eval(tt))
+
+ self.cand_examples = []
+ self.correct_num = 0
+ for pid in self.cand_pids:
+ example = create_example_from_pid(pid, self.problems, self.cfg, test=True)
+ self.cand_examples.append(example)
+
+ self._init_flag = True
+ self.problem_id = 0
+ train_sample = create_example_from_pid(self.train_pids[self.problem_id], self.problems, self.cfg, test=True)
+ obs = {'train_sample': train_sample, 'candidate_samples': self.cand_examples}
+ return obs
+
+ def search_answer(self, pid, pids):
+ for item in self.results_memory:
+ if item['pid'] != pid:
+ continue
+ if item['shot_pids'] == pids:
+ return item['output']
+
+        raise ValueError('Item does not exist.')
+
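+    # Enumerate every ordered 2-shot pair of candidate prompts for each training problem and cache
+    # the LM outputs to disk; these files back the replay mode used in reset() / search_answer().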
+ def parse_all_answers(self):
+        self.cand_pids = ['32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514',
+                          '19270', '23713', '17209', '33379', '34987', '11177', '30218', '26066', '24169', '28492']
+        self.train_pids = ['14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433',
+                           '26979', '18135', '13347', '17679', '38426', '3454', '10432', '31011', '12162', '13063',
+                           '7812', '29661', '24482', '4970', '4405', '17405', '27781', '26724', '5993', '16442',
+                           '30148', '15895', '6855', '29903', '18107', '29504', '11106', '32964', '29891', '32104',
+                           '15712', '24287', '4997', '32581', '21020', '17247', '31455', '13245', '15850', '10011',
+                           '10313', '10158', '1817', '33479', '35842', '14198', '26039', '3791', '4909', '37056',
+                           '7144', '8185', '2131', '4398', '38199', '29520', '37329', '21388', '28659', '15044',
+                           '28510', '12903', '11794', '37095', '32229', '22918', '31680', '15024', '24607', '26930']
+ self.problem_id = 0
+ self.cfg.train_number = len(self.train_pids)
+ n = len(self.cand_pids)
+
+ with open('sampled_pid.txt', 'w') as f:
+ f.write(str(self.cand_pids) + '\n')
+ f.write(str(self.train_pids) + '\n')
+
+ with open('model_in_out.txt', 'w') as f:
+ while self.problem_id < self.cfg.train_number:
+ for i in range(n):
+ for j in range(n):
+ if i == j:
+ continue
+ shot_pids = [self.cand_pids[i], self.cand_pids[j]]
+ pid = self.train_pids[self.problem_id]
+
+ # generate the prompt input
+ prompt = build_prompt(self.problems, shot_pids, pid, self.cfg)
+
+ # get the output from LM
+ # assert self._args.engine == 'text-davinci-002'
+ output = get_gpt3_output(prompt, self.cfg)
+
+ output_txt = {'shot_pids': shot_pids, 'pid': pid, 'prompt': prompt, 'output': output}
+ f.write(str(output_txt) + '\n')
+ print(self.problem_id, i, j)
+
+ self.problem_id += 1
+
+ def close(self) -> None:
+ self._init_flag = False
+
+    def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ shot_pids = [self.cand_pids[cid] for cid in action]
+ pid = self.train_pids[self.problem_id]
+
+ # generate the prompt input
+ prompt = build_prompt(self.problems, shot_pids, pid, self.cfg)
+
+ # get the output from LM
+ if self.enable_replay:
+ output = self.search_answer(pid, shot_pids)
+ elif self.cfg.engine == 'text-davinci-002':
+ output = get_gpt3_output(prompt, self.cfg)
+ elif self.cfg.engine == 'rwkv-7B':
+ output = calc_rwkv(self.model, self.tokenizer, prompt)
+ elif self.cfg.engine == 'internlm-7B':
+ output = calc_internlm(self.model, self.tokenizer, prompt, self.cfg)
+ else:
+ output = self.get_output(prompt)
+
+ # extract the prediction from the output
+ prediction = extract_prediction(output, self.problems[pid]['choices'], self.cfg.option_inds)
+
+ # normalize the number in the text
+ prediction_norm = normalize_answer(prediction, self.problems[pid]['unit'])
+
+ if prediction_norm.lower() == normalize_answer(self.problems[pid]['answer'],
+ self.problems[pid]['unit']).lower():
+ reward = 1
+ self.correct_num += 1
+ else:
+ reward = -1
+
+ self.problem_id += 1
+ if self.problem_id == self.cfg.train_number:
+ done = True
+ info = {'eval_episode_return': self.correct_num / self.cfg.train_number}
+ else:
+ done = False
+ info = {}
+
+ train_sample = create_example_from_pid(pid, self.problems, self.cfg, test=True)
+ obs = {'train_sample': train_sample, 'candidate_samples': self.cand_examples}
+
+ return BaseEnvTimestep(obs, reward, done, info)
+
+ def __repr__(self) -> str:
+ return "DI-engine tabmwp Env"
+
+
+if __name__ == '__main__':
+ from easydict import EasyDict
+ env_cfg = EasyDict(dict(
+ cand_number=16,
+ train_number=20,
+ engine='text-davinci-002',
+ temperature=0.,
+ max_tokens=512,
+ top_p=1.,
+ frequency_penalty=0.,
+ presence_penalty=0.,
+ option_inds=["A", "B", "C", "D", "E", "F"],
+ api_key='xxx',
+ prompt_format='TQ-A',
+ enable_replay=True,
+ seed=0,
+ ))
+ env = TabMWP(env_cfg)
+ env.seed(0)
+ env.reset()
+ env.parse_all_answers()
+ env.search_answer('22976', ['32889', '8044'])
+
diff --git a/dizoo/tabmwp/envs/test_tabmwp_env.py b/dizoo/tabmwp/envs/test_tabmwp_env.py
new file mode 100644
index 0000000000..ca9020d971
--- /dev/null
+++ b/dizoo/tabmwp/envs/test_tabmwp_env.py
@@ -0,0 +1,25 @@
+from easydict import EasyDict
+import pytest
+from dizoo.tabmwp.envs.tabmwp_env import TabMWP
+
+
+@pytest.mark.envtest
+class TestTabMWP:
+
+ def test_tabmwp(self):
+ config = dict(
+ cand_number=20,
+ train_number=100,
+ engine='text-davinci-002',
+ temperature=0.,
+ max_tokens=512,
+ top_p=1.,
+ frequency_penalty=0.,
+ presence_penalty=0.,
+ option_inds=["A", "B", "C", "D", "E", "F"],
+            api_key='',
+            enable_replay=True,
+        )
+ config = EasyDict(config)
+ env = TabMWP(config)
+ env.seed(0)
+ env.close()
diff --git a/dizoo/tabmwp/envs/utils.py b/dizoo/tabmwp/envs/utils.py
new file mode 100644
index 0000000000..f1f74a3f0c
--- /dev/null
+++ b/dizoo/tabmwp/envs/utils.py
@@ -0,0 +1,335 @@
+import json
+import os
+import random
+import re
+import time
+from functools import lru_cache
+import torch
+
+import numpy as np
+import openai
+try:
+ import transformers
+except ImportError:
+ import sys
+ from ditk import logging
+ logging.warning("not found transformer, please install it using: pip install transformers")
+ sys.exit(1)
+
+
+def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8) -> int:
+ # Sample an action given the logits.
+ probs = torch.softmax(out, dim=-1).cpu().numpy()
+ sorted_probs = np.sort(probs)[::-1]
+ cumulative_probs = np.cumsum(sorted_probs)
+ cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
+ probs[probs < cutoff] = 0
+ if temperature != 1.0:
+        probs = np.power(probs, 1.0 / temperature)
+ probs = probs / np.sum(probs)
+ out = np.random.choice(a=len(probs), p=probs)
+ return out
+
+
+def calc_rwkv(model: transformers.RwkvForCausalLM, tokenizer: transformers.AutoTokenizer,
+              prompt: str, max_len: int = 10) -> str:
+ # Use RWKV to generate sentence.
+ orig_len = len(prompt)
+ inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
+ outputs = model(**inputs, labels=inputs["input_ids"])
+ out, state = outputs.logits, outputs.state
+ # Recurrent generation.
+ with torch.no_grad():
+ for i in range(max_len):
+ token = sample_logits(out[0, -1])
+ tmp = tokenizer.decode([token])
+ prompt = prompt + tmp
+ inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
+ outputs = model(**inputs, labels=inputs["input_ids"])
+ out, state = outputs.logits, outputs.state
+ return prompt[orig_len:]
+
+
+def calc_internlm(model, tokenizer, prompt: str, args):
+ inputs = tokenizer(prompt, return_tensors="pt")
+ for k, v in inputs.items():
+ inputs[k] = v.cuda()
+    gen_kwargs = {"max_length": args.max_tokens, "top_p": args.top_p, "temperature": args.temperature,
+                  "do_sample": True, "repetition_penalty": args.frequency_penalty}
+    output = model.generate(**inputs, **gen_kwargs)
+    output = tokenizer.decode(output[0], skip_special_tokens=True)
+ return output
+
+
+def load_data(args: dict) -> tuple:
+ # Load tabmwp dataset.
+ random.seed(args.seed)
+ data_root = 'dizoo/tabmwp/data'
+
+ if not os.path.exists(data_root):
+ os.mkdir(data_root)
+
+    if not os.path.exists(os.path.join(data_root, 'problems_train.json')):
+        os.system('wget https://opendilab.net/download/DI-zoo/tabmwp/problems_train.json -O '
+                  + os.path.join(data_root, 'problems_train.json') + ' --no-check-certificate')
+    problems = json.load(open(os.path.join(data_root, 'problems_train.json')))
+
+ pids = list(problems.keys())
+ samples = random.sample(pids, args.train_number + args.cand_number) # random sample
+ train_pids = samples[:args.train_number]
+ cand_pids = samples[args.train_number:]
+ return problems, cand_pids, train_pids
+
+
+def get_gpt3_output(prompt: str, args: dict) -> str:
+ return call_gpt3(args.engine, prompt, args.temperature, args.max_tokens, args.top_p, args.frequency_penalty,
+ args.presence_penalty)
+
+
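+# Cache completions so identical prompts are answered from memory instead of re-querying the API.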
+@lru_cache(maxsize=10000)
+def call_gpt3(engine: str, prompt: str, temperature: float, max_tokens: int, top_p: float,
+ frequency_penalty: float, presence_penalty: float) -> str:
+ patience = 100
+ while True:
+ try:
+ response = openai.Completion.create(engine=engine,
+ prompt=prompt,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ top_p=top_p,
+ frequency_penalty=frequency_penalty,
+ presence_penalty=presence_penalty,
+ stop=["\n"])
+ output = response["choices"][0]["text"].strip()
+ break
+        except Exception:
+            patience -= 1
+            if patience == 0:
+                print("!!! Ran out of patience waiting for OpenAI")
+                output = ''
+                break
+            time.sleep(0.1)
+ return output
+
+
+def get_table_text(problem: dict) -> str:
+ table = problem['table']
+ title = problem['table_title']
+ if title and len(title) > 0:
+ table = f"[TITLE]: {title}\n{table}"
+ return table
+
+
+def get_question_text(problem: dict, option_inds: list) -> str:
+ question = problem['question']
+
+ unit = problem['unit']
+ if unit and len(unit) > 0:
+ question = f"{question} (Unit: {unit})"
+
+ choices = problem['choices']
+ if choices and len(choices) > 0:
+ choice_list = []
+ for i, c in enumerate(choices):
+ choice_list.append("({}) {}".format(option_inds[i], c))
+ options = " ".join(choice_list)
+ question = f"{question}\nOptions: {options}"
+
+ return question
+
+
+def get_answer(problem: dict) -> str:
+ return problem['answer']
+
+
+def get_solution_text(problem: dict) -> str:
+ # GPT-3 can generate the solution with more tokens
+ solution = problem['solution'].replace("\n", "\\n")
+ return solution
+
+
+def create_one_example(prompt_format: str, table: str, question: str, answer: str,
+                       solution: str, test_example: bool = True) -> str:
+    # Use a template to generate one prompt example.
+    input_format, output_format = prompt_format.split("-")  # e.g., "TQ-A"
+
+ elements = {
+ "Q": f"Question: {question}",
+ "T": f"Table: {table}",
+ "S": f"Solution: {solution}",
+ "A": f"Answer: The answer is {answer}.",
+ "AS": f"Answer: The answer is {answer}. BECAUSE: {solution}",
+ "SA": f"Answer: {solution} The answer is {answer}."
+ }
+
+    # Input
+    input_text = "\n".join(elements[label] for label in input_format)
+
+ # Output
+ if test_example:
+ output = "Answer:"
+ else:
+ output = elements[output_format]
+
+    # Prompt text
+    text = input_text + "\n" + output
+    text = text.replace("  ", " ").strip()
+
+ return text
+
+
+def build_prompt(problems: dict, shot_pids: list, test_pid: str, args: dict) -> str:
+ # Given ids, generate the complete prompt. That is, the input to LM.
+ examples = []
+ pids = shot_pids + [test_pid]
+
+ # n-shot training examples
+ for pid in pids:
+ problem = problems[pid]
+ table = get_table_text(problem)
+ question = get_question_text(problem, args.option_inds)
+ answer = get_answer(problem)
+ solution = get_solution_text(problems[pid])
+
+ if pid == test_pid:
+ assert pid not in shot_pids
+ example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=True)
+ else:
+ example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=False)
+
+ examples.append(example)
+
+ # create the prompt input
+ prompt_input = '\n\n'.join(examples)
+
+ return prompt_input
+
+
+def extract_prediction(output: str, options: list, option_inds: list) -> str:
+ idx = output.find('\n')
+ if idx > 0:
+ output = output[:idx]
+ idx = output.find('=')
+ if idx > 0:
+ output = output[idx + 1:].strip()
+ # $\\frac{16}{95}$ -> 16/95
+ output = re.sub(r"\$?\\frac\{([\d\.\,\-]+)\}\{([\d\.\,]+)\}\$?", r"\1/\2", output)
+
+ output = re.sub(r"(? 0:
+ pred = res[0].upper() # e.g., "B"
+ if pred in option_inds:
+ ind = option_inds.index(pred) # 1
+ if ind >= len(options):
+ ind = random.choice(range(len(options)))
+ predition = options[ind]
+ return predition
+
+ # find the most similar options
+ scores = [score_string_similarity(x, output) for x in options]
+ max_idx = int(np.argmax(scores)) # json does not recognize NumPy data types
+ predition = options[max_idx]
+ return predition
+
+ else:
+ # free_text QA problems, numeric answer
+ patterns = [
+ # r'^\([A-Za-z]\) ([\s\S]+)$', # "(A) XXXXX"
+ # r'[Th]he answer is \([A-Za-z]\) ([\s\S]+)$', # "The answer is (B) XXXXX."
+ r'[Th]he answer is ([\s\S]+)$', # "The answer is XXXXX.",
+ r'[Th]he table shows that ([\d\$\.\,\/\:]+) ',
+ r' = ([\d\$\.\,\/\:]+)', # "= $1.40"
+ r'(?<= be| is) ([\-\d\$\.\,\/\:]{0,}[\d]+)', # "will be $1.40"
+ r'(?<= are| was) ([\-\d\$\.\,\/\:]{0,}[\d]+)', # "are $1.40"
+ r'(?<= were) ([\-\d\$\.\,\/\:]{0,}[\d]+)', # "are $1.40"
+ r' ([\d\$\.\,\/\:]+ [AP]\.M\.)', # 7:25 P.M.
+ r'([\-\d\$\.\,\/\:]{0,}[\d]+)', # 14.5
+ ]
+
+ for p in patterns:
+ pattern = re.compile(p)
+ res = pattern.findall(output)
+ if len(res) > 0:
+                prediction = res[-1].strip()
+                if prediction.endswith(".") and ".M." not in prediction:
+                    prediction = prediction[:-1]
+                return prediction
+
+ return output
+
+
+def normalize_answer(text: str, unit: str) -> str:
+ # ["1,000", "123", "3/4", "56.456", "$56.4", "-3", "-10.02", "-3/2"]
+
+ text = re.sub("^[\$]", "", text)
+ text = re.sub("[\,\.\,\/]$", "", text)
+ result = re.match("^[-+]?[\d,./]+$", text)
+
+ if result is not None:
+ # is number?
+ text = text.replace(",", "")
+ result = re.match("[-+]?\d+$", text)
+ try:
+ if result is not None:
+ number = int(text)
+ elif "/" in text:
+ nums = text.split("/")
+ number = round(float(nums[0]) / float(nums[1]), 3)
+ else:
+ number = round(float(text), 3)
+ number = str(number)
+ number = re.sub(r"\.[0]+$", "", number)
+ return number
+        except Exception:
+ return text
+ else:
+ # is text
+ if unit:
+ text = text.replace(unit, "").strip()
+ return text
+
+
+def score_string_similarity(str1: str, str2: str) -> float:
+ if str1 == str2:
+ return 2.0
+ if " " in str1 or " " in str2:
+ str1_split = str1.split(" ")
+ str2_split = str2.split(" ")
+ overlap = list(set(str1_split) & set(str2_split))
+ return len(overlap) / max(len(str1_split), len(str2_split))
+    else:
+        return 0.0
+
+
+def create_example_from_pid(pid: str, problems: dict, args: dict, test: bool = False) -> str:
+ problem = problems[pid]
+ table = get_table_text(problem)
+ question = get_question_text(problem, args.option_inds)
+ answer = get_answer(problem)
+ solution = get_solution_text(problems[pid])
+
+ if test:
+ example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=True)
+ else:
+ example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=False)
+
+ return example