diff --git a/.github/workflows/algo_test.yml b/.github/workflows/algo_test.yml index e61fe104aa..02f46cd9a1 100644 --- a/.github/workflows/algo_test.yml +++ b/.github/workflows/algo_test.yml @@ -31,5 +31,6 @@ jobs: run: | python -m pip install . python -m pip install ".[test,k8s]" + python -m pip install transformers ./ding/scripts/install-k8s-tools.sh make algotest diff --git a/.github/workflows/envpool_test.yml b/.github/workflows/envpool_test.yml index 66ea2f8d11..e14a4848a0 100644 --- a/.github/workflows/envpool_test.yml +++ b/.github/workflows/envpool_test.yml @@ -24,5 +24,6 @@ jobs: python -m pip install . python -m pip install ".[test,k8s]" python -m pip install ".[envpool]" + python -m pip install transformers ./ding/scripts/install-k8s-tools.sh make envpooltest diff --git a/.github/workflows/platform_test.yml b/.github/workflows/platform_test.yml index 2194ab027c..9857639b3b 100644 --- a/.github/workflows/platform_test.yml +++ b/.github/workflows/platform_test.yml @@ -25,5 +25,6 @@ jobs: run: | python -m pip install . python -m pip install ".[test,k8s]" + python -m pip install transformers python -m pip uninstall pytest-timeouts -y make platformtest diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml index d58a863759..771dc596e3 100644 --- a/.github/workflows/unit_test.yml +++ b/.github/workflows/unit_test.yml @@ -25,6 +25,7 @@ jobs: python -m pip install box2d-py python -m pip install . python -m pip install ".[test,k8s]" + python -m pip install transformers ./ding/scripts/install-k8s-tools.sh make unittest - name: Upload coverage to Codecov @@ -53,5 +54,6 @@ jobs: run: | python -m pip install . python -m pip install ".[test,k8s]" + python -m pip install transformers ./ding/scripts/install-k8s-tools.sh make benchmark diff --git a/README.md b/README.md index a7f0e13c3b..7c26010214 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ It provides **python-first** and **asynchronous-native** task and middleware abs - [awesome-diffusion-model-in-rl](https://github.com/opendilab/awesome-diffusion-model-in-rl): A curated list of Diffusion Model in RL resources - [awesome-end-to-end-autonomous-driving](https://github.com/opendilab/awesome-end-to-end-autonomous-driving): A curated list of awesome End-to-End Autonomous Driving resources - [awesome-driving-behavior-prediction](https://github.com/opendilab/awesome-driving-behavior-prediction): A collection of research papers for Driving Behavior Prediction - + On the low-level end, DI-engine comes with a set of highly re-usable modules, including [RL optimization functions](https://github.com/opendilab/DI-engine/tree/main/ding/rl_utils), [PyTorch utilities](https://github.com/opendilab/DI-engine/tree/main/ding/torch_utils) and [auxiliary tools](https://github.com/opendilab/DI-engine/tree/main/ding/utils). @@ -210,51 +210,52 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo` | 7 | [SQL](https://arxiv.org/pdf/1702.08165.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [SQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sql.html)
[policy/sql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sql.py) | ding -m serial -c cartpole_sql_config.py -s 0 | | 8 | [R2D2](https://openreview.net/forum?id=r1lyTjAqYX) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [R2D2 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d2.html)
[policy/r2d2](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d2.py) | ding -m serial -c cartpole_r2d2_config.py -s 0 | | 9 | [PG](https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html)
[policy/pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pg.py) | ding -m serial -c cartpole_pg_config.py -s 0 | -| 10 | [A2C](https://arxiv.org/pdf/1602.01783.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [A2C doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html)
[policy/a2c](https://github.com/opendilab/DI-engine/blob/main/ding/policy/a2c.py) | ding -m serial -c cartpole_a2c_config.py -s 0 | -| 11 | [PPO](https://arxiv.org/abs/1707.06347)/[MAPPO](https://arxiv.org/pdf/2103.01955.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [PPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppo.html)
[policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | python3 -u cartpole_ppo_main.py / ding -m serial_onpolicy -c cartpole_ppo_config.py -s 0 | -| 12 | [PPG](https://arxiv.org/pdf/2009.04416.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppg.html)
[policy/ppg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppg.py) | python3 -u cartpole_ppg_main.py | -| 13 | [ACER](https://arxiv.org/pdf/1611.01224.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [ACER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/acer.html)
[policy/acer](https://github.com/opendilab/DI-engine/blob/main/ding/policy/acer.py) | ding -m serial -c cartpole_acer_config.py -s 0 | -| 14 | [IMPALA](https://arxiv.org/abs/1802.01561) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [IMPALA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/impala.html)
[policy/impala](https://github.com/opendilab/DI-engine/blob/main/ding/policy/impala.py) | ding -m serial -c cartpole_impala_config.py -s 0 | -| 15 | [DDPG](https://arxiv.org/pdf/1509.02971.pdf)/[PADDPG](https://arxiv.org/pdf/1511.04143.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [DDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)
[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c pendulum_ddpg_config.py -s 0 | -| 16 | [TD3](https://arxiv.org/pdf/1802.09477.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [TD3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3.html)
[policy/td3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3.py) | python3 -u pendulum_td3_main.py / ding -m serial -c pendulum_td3_config.py -s 0 | -| 17 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [D4PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/d4pg.html)
[policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py | -| 18 | [SAC](https://arxiv.org/abs/1801.01290)/[MASAC] | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [SAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sac.html)
[policy/sac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sac.py) | ding -m serial -c pendulum_sac_config.py -s 0 | -| 19 | [PDQN](https://arxiv.org/pdf/1810.06394.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_pdqn_config.py -s 0 | -| 20 | [MPDQN](https://arxiv.org/pdf/1905.04388.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_mpdqn_config.py -s 0 | -| 21 | [HPPO](https://arxiv.org/pdf/1903.01344.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | ding -m serial_onpolicy -c gym_hybrid_hppo_config.py -s 0 | -| 22 | [BDQ](https://arxiv.org/pdf/1711.08946.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/bdq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u hopper_bdq_config.py | -| 23 | [MDQN](https://arxiv.org/abs/2007.14430) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/mdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mdqn.py) | python3 -u asterix_mdqn_config.py | -| 24 | [QMIX](https://arxiv.org/pdf/1803.11485.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [QMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qmix.html)
[policy/qmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qmix.py) | ding -m serial -c smac_3s5z_qmix_config.py -s 0 | -| 25 | [COMA](https://arxiv.org/pdf/1705.08926.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [COMA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/coma.html)
[policy/coma](https://github.com/opendilab/DI-engine/blob/main/ding/policy/coma.py) | ding -m serial -c smac_3s5z_coma_config.py -s 0 | -| 26 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 | -| 27 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [WQMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/wqmix.html)
[policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 | -| 28 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [CollaQ doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/collaq.html)
[policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 | -| 29 | [MADDPG](https://arxiv.org/pdf/1706.02275.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [MADDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)
[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c ant_maddpg_config.py -s 0 | -| 30 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [GAIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/gail.html)
[reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 | -| 31 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [SQIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sqil.html)
[entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 | -| 32 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [DQFD doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqfd.html)
[policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 | -| 33 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [R2D3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d3.html)
[R2D3中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html)
[policy/r2d3](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html) | python3 -u pong_r2d3_r2d2expert_config.py | -| 34 | [Guided Cost Learning](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [Guided Cost Learning中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/guided_cost_zh.html)
[reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py | -| 35 | [TREX](https://arxiv.org/abs/1904.06387) | ![IL](https://img.shields.io/badge/-IL-purple) | [TREX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/trex.html)
[reward_model/trex](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/trex_reward_model.py) | python3 mujoco_trex_main.py | -| 36 | [Implicit Behavorial Cloning](https://implicitbc.github.io/) (DFO+MCMC) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/ibc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ibc.py)
[model/template/ebm](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/ebm.py) | python3 d4rl_ibc_main.py -s 0 -c pen_human_ibc_mcmc_config.py | -| 37 | [BCO](https://arxiv.org/pdf/1805.01954.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/bco](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_bco.py) | python3 -u cartpole_bco_config.py | -| 38 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [HER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/her.html)
[reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py | -| 39 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [RND doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/rnd.html)
[reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_rnd_onppo_config.py | -| 40 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [ICM doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/icm.html)
[ICM中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/icm_zh.html)
[reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py | -| 41 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [CQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/cql.html)
[policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py | -| 42 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [TD3BC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3_bc.html)
[policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u d4rl_td3_bc_main.py | -| 43 | [Decision Transformer](https://arxiv.org/pdf/2106.01345.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/dt](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dt.py) | python3 -u ding/example/dt.py | -| 44 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)
[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py | -| 45 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py | -| 46 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py | -| 47 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)
[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py | -| 48 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py | -| 49 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py | -| 50 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` | -| 51 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` | -| 52 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 | -| 53 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)
[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 | -| 54 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 | +| 10 | [PromptPG](https://arxiv.org/abs/2209.14610) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/prompt_pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_pg.py) | ding -m serial_onpolicy -c tabmwp_pg_config.py -s 0 | +| 11 | [A2C](https://arxiv.org/pdf/1602.01783.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [A2C doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html)
[policy/a2c](https://github.com/opendilab/DI-engine/blob/main/ding/policy/a2c.py) | ding -m serial -c cartpole_a2c_config.py -s 0 | +| 12 | [PPO](https://arxiv.org/abs/1707.06347)/[MAPPO](https://arxiv.org/pdf/2103.01955.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [PPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppo.html)
[policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | python3 -u cartpole_ppo_main.py / ding -m serial_onpolicy -c cartpole_ppo_config.py -s 0 | +| 13 | [PPG](https://arxiv.org/pdf/2009.04416.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppg.html)
[policy/ppg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppg.py) | python3 -u cartpole_ppg_main.py | +| 14 | [ACER](https://arxiv.org/pdf/1611.01224.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [ACER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/acer.html)
[policy/acer](https://github.com/opendilab/DI-engine/blob/main/ding/policy/acer.py) | ding -m serial -c cartpole_acer_config.py -s 0 | +| 15 | [IMPALA](https://arxiv.org/abs/1802.01561) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [IMPALA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/impala.html)
[policy/impala](https://github.com/opendilab/DI-engine/blob/main/ding/policy/impala.py) | ding -m serial -c cartpole_impala_config.py -s 0 | +| 16 | [DDPG](https://arxiv.org/pdf/1509.02971.pdf)/[PADDPG](https://arxiv.org/pdf/1511.04143.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [DDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)
[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c pendulum_ddpg_config.py -s 0 | +| 17 | [TD3](https://arxiv.org/pdf/1802.09477.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [TD3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3.html)
[policy/td3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3.py) | python3 -u pendulum_td3_main.py / ding -m serial -c pendulum_td3_config.py -s 0 | +| 18 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [D4PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/d4pg.html)
[policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py | +| 19 | [SAC](https://arxiv.org/abs/1801.01290)/[MASAC] | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [SAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sac.html)
[policy/sac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sac.py) | ding -m serial -c pendulum_sac_config.py -s 0 | +| 20 | [PDQN](https://arxiv.org/pdf/1810.06394.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_pdqn_config.py -s 0 | +| 21 | [MPDQN](https://arxiv.org/pdf/1905.04388.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_mpdqn_config.py -s 0 | +| 22 | [HPPO](https://arxiv.org/pdf/1903.01344.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | ding -m serial_onpolicy -c gym_hybrid_hppo_config.py -s 0 | +| 23 | [BDQ](https://arxiv.org/pdf/1711.08946.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/bdq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u hopper_bdq_config.py | +| 24 | [MDQN](https://arxiv.org/abs/2007.14430) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/mdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mdqn.py) | python3 -u asterix_mdqn_config.py | +| 25 | [QMIX](https://arxiv.org/pdf/1803.11485.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [QMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qmix.html)
[policy/qmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qmix.py) | ding -m serial -c smac_3s5z_qmix_config.py -s 0 | +| 26 | [COMA](https://arxiv.org/pdf/1705.08926.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [COMA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/coma.html)
[policy/coma](https://github.com/opendilab/DI-engine/blob/main/ding/policy/coma.py) | ding -m serial -c smac_3s5z_coma_config.py -s 0 | +| 27 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 | +| 28 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [WQMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/wqmix.html)
[policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 | +| 29 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [CollaQ doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/collaq.html)
[policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 | +| 30 | [MADDPG](https://arxiv.org/pdf/1706.02275.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [MADDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)
[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c ant_maddpg_config.py -s 0 | +| 31 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [GAIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/gail.html)
[reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 | +| 32 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [SQIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sqil.html)
[entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 | +| 33 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [DQFD doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqfd.html)
[policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 | +| 34 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [R2D3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d3.html)
[R2D3中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html)
[policy/r2d3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d3.py) | python3 -u pong_r2d3_r2d2expert_config.py | +| 35 | [Guided Cost Learning](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [Guided Cost Learning中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/guided_cost_zh.html)
[reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py | +| 36 | [TREX](https://arxiv.org/abs/1904.06387) | ![IL](https://img.shields.io/badge/-IL-purple) | [TREX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/trex.html)
[reward_model/trex](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/trex_reward_model.py) | python3 mujoco_trex_main.py | +| 37 | [Implicit Behavioral Cloning](https://implicitbc.github.io/) (DFO+MCMC) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/ibc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ibc.py)
[model/template/ebm](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/ebm.py) | python3 d4rl_ibc_main.py -s 0 -c pen_human_ibc_mcmc_config.py | +| 38 | [BCO](https://arxiv.org/pdf/1805.01954.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/bco](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_bco.py) | python3 -u cartpole_bco_config.py | +| 39 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [HER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/her.html)
[reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py | +| 40 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [RND doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/rnd.html)
[reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_rnd_onppo_config.py | +| 41 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [ICM doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/icm.html)
[ICM中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/icm_zh.html)
[reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py | +| 42 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [CQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/cql.html)
[policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py | +| 43 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [TD3BC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3_bc.html)
[policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u d4rl_td3_bc_main.py | +| 44 | [Decision Transformer](https://arxiv.org/pdf/2106.01345.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/dt](https://github.com/opendilab/DI-engine/blob/main/ding/policy/decision_transformer.py) | python3 -u d4rl_dt_main.py | +| 45 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html)
[policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py | +| 46 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py \ python3 -u pendulum_mbsac_ddppo_config.py | +| 47 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py | +| 48 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html)
[world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py | +| 49 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py | +| 50 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py | +| 51 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` | +| 52 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` | +| 53 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 | +| 54 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html)
[data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 | +| 55 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 | @@ -299,7 +300,8 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo` | 33 |[classic_control/acrobot](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/acrobot/acrobot.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/acrobot/envs)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/acrobot_zh.html) | | 34 |[box2d/car_racing](https://github.com/openai/gym/blob/master/gym/envs/box2d/car_racing.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)
![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/box2d/carracing/car_racing.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/carracing/envs)
环境指南 | | 35 |[metadrive](https://github.com/metadriverse/metadrive) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/metadrive/metadrive_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/metadrive/env)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/metadrive_zh.html) | -| 36 |[cliffwalking](https://github.com/openai/gym/blob/master/gym/envs/toy_text/cliffwalking.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/cliffwalking/cliff_walking.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/cliffwalking/envs)
环境指南 | +| 36 |tabmwp | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.PNG) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp/envs) | +| 37 |[cliffwalking](https://github.com/openai/gym/blob/master/gym/envs/toy_text/cliffwalking.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/cliffwalking/cliff_walking.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/cliffwalking/envs)
环境指南 | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space @@ -377,7 +379,7 @@ DI-engine utilizes [TreeTensor](https://github.com/opendilab/DI-treetensor) as t ```diff import torch import treetensor.torch as ttorch - + B = 4 @@ -410,7 +412,7 @@ DI-engine utilizes [TreeTensor](https://github.com/opendilab/DI-treetensor) as t - stacked_data = stack(data, dim=0) + data = [ttorch.tensor(d) for d in data] + stacked_data = ttorch.stack(data, dim=0) - + # validate - assert stacked_data['obs']['image'].shape == (B, 3, 32, 32) - assert stacked_data['action'].shape == (B, 1) diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index 59fdfaac82..a291936d64 100755 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -5,6 +5,7 @@ from .vac import VAC, DREAMERVAC from .bc import DiscreteBC, ContinuousBC from .pg import PG +from .language_transformer import LanguageTransformer # algorithm-specific from .ppg import PPG from .qmix import Mixer, QMix diff --git a/ding/model/template/language_transformer.py b/ding/model/template/language_transformer.py new file mode 100644 index 0000000000..d9541bcf82 --- /dev/null +++ b/ding/model/template/language_transformer.py @@ -0,0 +1,63 @@ +import torch + +from ding.utils import MODEL_REGISTRY +from torch import nn +try: + from transformers import AutoTokenizer, AutoModelForTokenClassification +except ImportError: + import sys + from ditk import logging + logging.warning("not found transformer, please install it using: pip install transformers") + sys.exit(1) + + +@MODEL_REGISTRY.register('language_transformer') +class LanguageTransformer(nn.Module): + + def __init__( + self, + model_name: str = "bert-base-uncased", + add_linear: bool = False, + embedding_size: int = 128, + freeze_encoder: bool = True + ) -> None: + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForTokenClassification.from_pretrained(model_name) + + # Freeze transformer encoder and only train the linear layer + if freeze_encoder: + for param in self.model.parameters(): + param.requires_grad = False + + if add_linear: + # Add an additional small, adjustable linear layer on top of BERT tuned through RL + self.embedding_size = embedding_size + self.linear = nn.Linear( + self.model.config.hidden_size, embedding_size + ) # 768 for bert-base-uncased, distilbert-base-uncased + else: + self.linear = None + + def _calc_embedding(self, x: list) -> torch.Tensor: + # ``truncation=True`` means that if the length of the prompt exceed the ``max_length`` of the tokenizer, + # the exceeded part will be truncated. ``padding=True`` means that if the length of the prompt does not reach + # the ``max_length``, the latter part will be padded. These settings ensure the length of encoded tokens is + # exactly ``max_length``, which can enable batch-wise computing. 
+ input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device) + output = self.model(**input, output_hidden_states=True) + # Get last layer hidden states + last_hidden_states = output.hidden_states[-1] + # Get [CLS] hidden states + sentence_embedding = last_hidden_states[:, 0, :] # len(input_list) x hidden_size + + if self.linear: + sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size + + return sentence_embedding + + def forward(self, train_samples: list, candidate_samples: list) -> dict: + prompt_embedding = self._calc_embedding(train_samples) + cands_embedding = self._calc_embedding(candidate_samples) + scores = torch.mm(prompt_embedding, cands_embedding.t()) + return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores} diff --git a/ding/model/template/tests/test_language_transformer.py b/ding/model/template/tests/test_language_transformer.py new file mode 100644 index 0000000000..40095c2ab2 --- /dev/null +++ b/ding/model/template/tests/test_language_transformer.py @@ -0,0 +1,25 @@ +import pytest + +from ding.model.template.language_transformer import LanguageTransformer + + +@pytest.mark.unittest +class TestNLPPretrainedModel: + + def check_model(self): + test_pids = [1] + cand_pids = [0, 2, 4] + problems = [ + "This is problem 0", "This is the first question", "Second problem is here", "Another problem", + "This is the last problem" + ] + ctxt_list = [problems[pid] for pid in test_pids] + cands_list = [problems[pid] for pid in cand_pids] + + model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256) + scores = model(ctxt_list, cands_list) + assert scores.shape == (1, 3) + + model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256) + scores = model(ctxt_list, cands_list) + assert scores.shape == (1, 3) diff --git a/ding/model/wrapper/model_wrappers.py b/ding/model/wrapper/model_wrappers.py index 9da90bcaa9..38c6346f2e 100644 --- a/ding/model/wrapper/model_wrappers.py +++ b/ding/model/wrapper/model_wrappers.py @@ -411,6 +411,53 @@ def forward(self, *args, **kwargs): return output +class CombinationArgmaxSampleWrapper(IModelWrapper): + r""" + Overview: + Used to help the model to sample combination argmax action. + """ + + def forward(self, shot_number, *args, **kwargs): + output = self._model.forward(*args, **kwargs) + # Generate actions. + act = [] + mask = torch.zeros_like(output['logit']) + for ii in range(shot_number): + masked_logit = output['logit'] + mask + actions = masked_logit.argmax(dim=-1) + act.append(actions) + for jj in range(actions.shape[0]): + mask[jj][actions[jj]] = -1e8 + # `act` is shaped: (B, shot_number) + act = torch.stack(act, dim=1) + output['action'] = act + return output + + +class CombinationMultinomialSampleWrapper(IModelWrapper): + r""" + Overview: + Used to help the model to sample combination multinomial action. + """ + + def forward(self, shot_number, *args, **kwargs): + output = self._model.forward(*args, **kwargs) + # Generate actions. 
+ act = [] + mask = torch.zeros_like(output['logit']) + for ii in range(shot_number): + dist = torch.distributions.Categorical(logits=output['logit'] + mask) + actions = dist.sample() + act.append(actions) + for jj in range(actions.shape[0]): + mask[jj][actions[jj]] = -1e8 + + # `act` is shaped: (B, shot_number) + act = torch.stack(act, dim=1) + output['action'] = act + return output + + class HybridArgmaxSampleWrapper(IModelWrapper): r""" Overview: @@ -906,6 +953,8 @@ def __init__(self, model, teacher_cfg): # model wrapper 'target': TargetNetworkWrapper, 'teacher': TeacherNetworkWrapper, + 'combination_argmax_sample': CombinationArgmaxSampleWrapper, + 'combination_multinomial_sample': CombinationMultinomialSampleWrapper, } diff --git a/ding/model/wrapper/test_model_wrappers.py b/ding/model/wrapper/test_model_wrappers.py index 274d2801bb..93334bee00 100644 --- a/ding/model/wrapper/test_model_wrappers.py +++ b/ding/model/wrapper/test_model_wrappers.py @@ -549,3 +549,17 @@ def test_transformer_memory_wrapper(self): assert sum(new_memory2[:, -16:].flatten()) != 0 assert sum(new_memory2[:, :-16].flatten()) == 0 assert torch.all(torch.eq(new_memory1[:, -8:], new_memory2[:, -16:-8])) + + def test_combination_argmax_sample_wrapper(self): + model = model_wrap(ActorMLP(), wrapper_name='combination_argmax_sample') + data = {'obs': torch.randn(4, 3)} + output = model.forward(shot_number=2, inputs=data) + assert output['action'].shape == (4, ) + assert (output['action'] >= 0).all() and (output['action'] < 64).all() + + def test_combination_multinomial_sample_wrapper(self): + model = model_wrap(ActorMLP(), wrapper_name='combination_multinomial_sample') + data = {'obs': torch.randn(4, 3)} + output = model.forward(shot_number=2, inputs=data) + assert output['action'].shape == (4, ) + assert (output['action'] >= 0).all() and (output['action'] < 64).all() diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index 9789de56f6..794a7f94fa 100755 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -54,3 +54,4 @@ # new-type policy from .ppof import PPOFPolicy +from .prompt_pg import PromptPGPolicy diff --git a/ding/policy/command_mode_policy_instance.py b/ding/policy/command_mode_policy_instance.py old mode 100755 new mode 100644 index 99929f56ef..3ecfeb204e --- a/ding/policy/command_mode_policy_instance.py +++ b/ding/policy/command_mode_policy_instance.py @@ -50,6 +50,7 @@ from .bdq import BDQPolicy from .bcq import BCQPolicy from .edac import EDACPolicy +from .prompt_pg import PromptPGPolicy class EpsCommandModePolicy(CommandModePolicy): @@ -438,3 +439,8 @@ def _get_setting_learn(self, command_info: dict) -> dict: def _get_setting_eval(self, command_info: dict) -> dict: return {} + + +@POLICY_REGISTRY.register('prompt_pg_command') +class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy): + pass diff --git a/ding/policy/prompt_pg.py b/ding/policy/prompt_pg.py new file mode 100644 index 0000000000..ebccadb8a3 --- /dev/null +++ b/ding/policy/prompt_pg.py @@ -0,0 +1,206 @@ +from typing import List, Dict, Any, Tuple, Union +from collections import namedtuple +import torch + +from ding.rl_utils import get_train_sample +from ding.torch_utils import Adam, to_device +from ding.utils import POLICY_REGISTRY, split_data_generator +from ding.utils.data import default_collate, default_decollate +from .base_policy import Policy +from ..model import model_wrap + + +@POLICY_REGISTRY.register('prompt_pg') +class PromptPGPolicy(Policy): + r""" + Overview: + Policy class of Prompt Policy 
Gradient (PromptPG) algorithm. + Link of the original paper: https://arxiv.org/abs/2209.14610 + """ + config = dict( + # (string) RL policy register name (refer to function "register_policy"). + type='prompt_pg', + # (bool) whether to use cuda for network. + cuda=True, + # (bool) whether use on-policy training pipeline(behaviour policy and training policy are the same) + on_policy=True, # for pg strictly on policy algorithm, this line should not be modified by users + # (bool) whether to use deterministic action for evaluation. + deterministic_eval=True, + learn=dict( + # (int) the number of samples for one update. + batch_size=64, + # (float) the step size of one gradient descend. + learning_rate=0.001, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) loss weight of the entropy regularization, the weight of policy network is set to 1 + entropy_weight=0.01, + # (float) max grad norm value. + grad_norm=5, + # (bool) whether to ignore done signal for non-termination env. + ignore_done=False, + ), + collect=dict( + # (int) collect n_sample data, train model n_iteration times + # n_episode=8, + # (int) trajectory unroll length + unroll_len=1, + # ============================================================== + # The following configs is algorithm-specific + # ============================================================== + # (float) discount factor for future reward, defaults int [0, 1] + discount_factor=0, + collector=dict(get_train_sample=True), + ), + eval=dict(), + ) + + def default_model(self) -> Tuple[str, List[str]]: + return 'language_transformer', ['ding.model.template.language_transformer'] + + def _init_learn(self) -> None: + r""" + Overview: + Learn mode init method. Called by ``self.__init__``. + Init the optimizer, algorithm config, main and target models. + """ + # Optimizer + self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate) + + self._entropy_weight = self._cfg.learn.entropy_weight + self._grad_norm = self._cfg.learn.grad_norm + self._learn_model = self._model # for compatibility + + def _forward_learn(self, data: dict) -> Dict[str, Any]: + r""" + Overview: + Forward and backward function of learn mode. + Arguments: + - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward'] + Returns: + - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss. + """ + self._model.train() + + return_infos = [] + for i in range(0, len(data), self._cfg.learn.batch_size): + batch = default_collate(data[i:i + self._cfg.learn.batch_size]) + if self._cuda: + batch = to_device(batch, self._device) + + # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected) + train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"] + for ii in range(len(cand_samples)): + cand_samples[ii] = cand_samples[ii][0] + output = self._learn_model.forward(train_samples, cand_samples) + return_ = batch['return'] + + # calculate PG loss + real_act = batch['action'] # shape: (B, shot_number) + # Calculate loss. 
+ total_policy_loss, total_entropy_loss = 0, 0 + for ii in range(self._cfg.shot_number): + log_prob = output['dist'].log_prob(real_act[:, ii]) + policy_loss = -(log_prob * return_).mean() + total_policy_loss += policy_loss + total_entropy_loss += -self._cfg.learn.entropy_weight * output['dist'].entropy().mean() + total_loss = total_entropy_loss + total_policy_loss + + # update + self._optimizer.zero_grad() + total_loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + list(self._learn_model.parameters()), + max_norm=self._grad_norm, + ) + self._optimizer.step() + + # only record last updates information in logger + return_info = { + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'total_loss': total_loss.item(), + 'policy_loss': total_policy_loss.item(), + 'entropy_loss': total_entropy_loss.item(), + 'return_abs_max': return_.abs().max().item(), + 'grad_norm': grad_norm, + } + return_infos.append(return_info) + return return_infos + + def _init_collect(self) -> None: + self._unroll_len = self._cfg.collect.unroll_len + self._gamma = self._cfg.collect.discount_factor + self._collect_model = model_wrap(self._model, wrapper_name='combination_multinomial_sample') + + def _forward_collect(self, data: dict) -> dict: + data_id = list(data.keys()) + data = default_collate(list(data.values())) + self._model.eval() + with torch.no_grad(): + # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected) + for ii in range(len(data['candidate_samples'])): + data['candidate_samples'][ii] = data['candidate_samples'][ii][0] + output = self._collect_model.forward(self._cfg.shot_number, data['train_sample'], data['candidate_samples']) + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + r""" + Overview: + Generate dict type transition data from inputs. + Arguments: + - obs (:obj:`Any`): Env observation + - model_output (:obj:`dict`): Output of collect model, including at least ['action'] + - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \ + (here 'obs' indicates obs after env step). + Returns: + - transition (:obj:`dict`): Dict type transition data. + """ + return { + 'obs': obs, + 'action': model_output['action'], + 'reward': timestep.reward, + 'done': timestep.done, + } + + def _get_train_sample(self, data: list) -> Union[None, List[Any]]: + r""" + Overview: + Get the trajectory and the n step return data, then sample from the n_step return data + Arguments: + - data (:obj:`list`): The trajectory's buffer list + Returns: + - samples (:obj:`dict`): The training samples generated + """ + if self._cfg.learn.ignore_done: + raise NotImplementedError + + R = 0. 
+ for i in reversed(range(len(data))): + R = self._gamma * R + data[i]['reward'] + data[i]['return'] = R + return get_train_sample(data, self._unroll_len) + + def _init_eval(self) -> None: + self._eval_model = model_wrap(self._model, wrapper_name='combination_argmax_sample') + + def _forward_eval(self, data: dict) -> dict: + data_id = list(data.keys()) + data = default_collate(list(data.values())) + self._model.eval() + with torch.no_grad(): + # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected) + for ii in range(len(data['candidate_samples'])): + data['candidate_samples'][ii] = data['candidate_samples'][ii][0] + output = self._eval_model.forward(self._cfg.shot_number, data['train_sample'], data['candidate_samples']) + if self._cuda: + output = to_device(output, 'cpu') + output = default_decollate(output) + return {i: d for i, d in zip(data_id, output)} + + def _monitor_vars_learn(self) -> List[str]: + return super()._monitor_vars_learn() + ['policy_loss', 'entropy_loss', 'return_abs_max', 'grad_norm'] diff --git a/dizoo/tabmwp/README.md b/dizoo/tabmwp/README.md new file mode 100644 index 0000000000..410aed8e6f --- /dev/null +++ b/dizoo/tabmwp/README.md @@ -0,0 +1,16 @@ +## TabMWP Env + +## Dataset + +The **TabMWP** dataset contains 38,431 tabular math word problems. Each question in **TabMWP** is aligned with a tabular context, which is presented as an image, semi-structured text, and a structured table. There are two types of questions: *free-text* and *multi-choice*, and each problem is annotated with gold solutions to reveal the multi-step reasoning process. + +The environment is described in the paper [Dynamic Prompt Learning via Policy Gradient for Semi-structured Mathematical Reasoning](https://arxiv.org/abs/2209.14610) by Pan Lu, Liang Qiu, Kai-Wei Chang, Ying Nian Wu, Song-Chun Zhu, Tanmay Rajpurohit, Peter Clark, Ashwin Kalyan, 2023. + +You can find more details in [Prompt PG](https://github.com/lupantech/PromptPG) + +## Benchmark + +- We collect the responses of GPT-3 using a reduced dataset with 80 training samples and 16 candidates. In this way, there is no need for users to interact with GPT-3 using the API-key of openai. +- You can directly reproduce the benchmark by running ``python dizoo/tabmwp/configs/tabmwp_pg_config.py`` + +![origin](./benchmark.png) diff --git a/dizoo/tabmwp/__init__.py b/dizoo/tabmwp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dizoo/tabmwp/benchmark.png b/dizoo/tabmwp/benchmark.png new file mode 100644 index 0000000000..9dfccc56c0 Binary files /dev/null and b/dizoo/tabmwp/benchmark.png differ diff --git a/dizoo/tabmwp/config/tabmwp_pg_config.py b/dizoo/tabmwp/config/tabmwp_pg_config.py new file mode 100644 index 0000000000..acda7bcdbd --- /dev/null +++ b/dizoo/tabmwp/config/tabmwp_pg_config.py @@ -0,0 +1,66 @@ +from easydict import EasyDict + +tabmwp_prompt_pg_config = dict( + exp_name='tabmwp_prompt_pg_seed0', + env=dict( + collector_env_num=1, + evaluator_env_num=1, + n_evaluator_episode=1, + stop_value=1, + cand_number=16, + train_number=80, + engine='text-davinci-002', + temperature=0., + max_tokens=512, + top_p=1., + frequency_penalty=0., + presence_penalty=0., + option_inds=["A", "B", "C", "D", "E", "F"], + # The API-key of openai. 
You can get your key in this website: https://platform.openai.com/ + api_key='', + enable_replay=True, + prompt_format='TQ-A', + seed=0, + ), + policy=dict( + cuda=True, + shot_number=2, + model=dict( + model_name="bert-base-uncased", + add_linear=True, + freeze_encoder=True, + embedding_size=128, + ), + learn=dict( + batch_size=10, + # (bool) Whether to normalize advantage. Default to False. + learning_rate=0.001, + # (float) loss weight of the value network, the weight of policy network is set to 1 + entropy_weight=0.001, + weight_decay=5e-3, + grad_norm=0.5, + ), + collect=dict( + # (int) collect n_sample data, train model 1 times + n_sample=20, + discount_factor=0., + ), + eval=dict(evaluator=dict(eval_freq=500, )), + ), +) +main_config = EasyDict(tabmwp_prompt_pg_config) + +tabmwp_prompt_pg_config = dict( + env=dict( + type='tabmwp', + import_names=['dizoo.tabmwp.envs.tabmwp_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='prompt_pg'), + replay_buffer=dict(type='naive'), +) +create_config = EasyDict(tabmwp_prompt_pg_config) + +if __name__ == '__main__': + from ding.entry import serial_pipeline_onpolicy + serial_pipeline_onpolicy((main_config, create_config), seed=0) diff --git a/dizoo/tabmwp/envs/__init__.py b/dizoo/tabmwp/envs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/dizoo/tabmwp/envs/tabmwp_env.py b/dizoo/tabmwp/envs/tabmwp_env.py new file mode 100644 index 0000000000..4da1fdfe98 --- /dev/null +++ b/dizoo/tabmwp/envs/tabmwp_env.py @@ -0,0 +1,244 @@ +import os +from functools import lru_cache + +import gym +import openai +import numpy as np + +from ding.utils import ENV_REGISTRY +from ding.envs import BaseEnv, BaseEnvTimestep +from dizoo.tabmwp.envs.utils import create_example_from_pid, build_prompt, get_gpt3_output, calc_rwkv, calc_internlm,\ + extract_prediction, normalize_answer, load_data + + +@ENV_REGISTRY.register('tabmwp') +class TabMWP(BaseEnv): + model = None + tokenizer = None + + def __init__(self, cfg): + self.cfg = cfg + self.enable_replay = cfg.enable_replay + self._init_flag = False + self.problems, self.cand_pids, self.train_pids = None, None, None + self.problem_id = 0 + self.cand_examples = [] + openai.api_key = cfg.api_key + self.observation_space = gym.spaces.Dict() + self.action_space = gym.spaces.Discrete(self.cfg.cand_number * (self.cfg.cand_number - 1)) + self.reward_space = gym.spaces.Box( + low=-1, high=1, shape=(1,), dtype=np.float32 + ) + self.correct_num = 0 + + # Initialize language model if needed. 
+ assert self.cfg.engine in ['text-davinci-002', 'glm-10B', 'rwkv-7B', 'internlm-7B'] + + try: + if self.cfg.engine == 'glm-10B' and TabMWP.model is None: + from transformers import AutoTokenizer, AutoModelForSeq2SeqLM + TabMWP.tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-10b", trust_remote_code=True) + model = AutoModelForSeq2SeqLM.from_pretrained("THUDM/glm-10b", trust_remote_code=True) + TabMWP.model = model.half() + elif self.cfg.engine == 'rwkv-7B' and TabMWP.model is None: + from transformers import AutoTokenizer, RwkvForCausalLM + TabMWP.tokenizer = AutoTokenizer.from_pretrained("sgugger/rwkv-7b-pile", trust_remote_code=True) + model = RwkvForCausalLM.from_pretrained("sgugger/rwkv-7b-pile") + TabMWP.model = model.half() + elif self.cfg.engine == 'internlm-7B' and TabMWP.model is None: + from transformers import AutoTokenizer, AutoModelForCausalLM + TabMWP.tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-7b", trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained("internlm/internlm-7b", trust_remote_code=True) + TabMWP.model = model.eval() + except ImportError: + import sys + from ditk import logging + logging.warning("not found transformer, please install it using: pip install transformers") + sys.exit(1) + + @lru_cache(maxsize=10000) + def get_output(self, inp: str) -> str: + inputs = TabMWP.tokenizer(inp + " [MASK].", return_tensors="pt") + inputs = TabMWP.tokenizer.build_inputs_for_generation(inputs, max_gen_length=512) + inputs = {key: value.cuda() for key, value in inputs.items()} + outputs = TabMWP.model.generate(**inputs, max_length=512, eos_token_id=TabMWP.tokenizer.eop_token_id, + pad_token_id=TabMWP.tokenizer.eos_token_id) + outputs = TabMWP.tokenizer.decode(outputs[0].tolist()) + + t0 = outputs.find('<|startofpiece|>') + 16 + t1 = outputs.find('<|endofpiece|>') + + return outputs[t0:t1] + + def seed(self, seed: int, dynamic_seed: bool = False) -> None: + self.cfg.seed = seed + + def reset(self) -> dict: + self.problems, self.cand_pids, self.train_pids = load_data(self.cfg) + if TabMWP.model is not None: + TabMWP.model = TabMWP.model.cuda() + if self.enable_replay: + self.cand_pids = ['32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514', + '19270', '23713', '17209', '33379', '34987', '11177'] + if self.cfg.seed == 0: # train + self.train_pids = ['14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433', + '26979', '18135', '13347', '17679', '38426', '3454', '10432', '31011', '12162', + '13063', '7812', '29661', '24482', '4970', '4405', '17405', '27781', '26724', '5993', + '16442', '30148', '15895', '6855', '29903', '18107', '29504', '11106', '32964', + '29891', '32104', '15712', '24287', '4997', '32581', '21020', '17247', '31455', + '13245', '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198', + '26039', '3791', '4909', '37056', '7144', '8185', '2131', '4398', '38199', '29520', + '37329', '21388', '28659', '15044', '28510', '12903', '11794', '37095', '32229', + '22918', '31680', '15024', '24607', '26930'] + model_io_path = 'dizoo/tabmwp/data/model_in_out_train.txt' + if not os.path.exists(model_io_path): + os.system(f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_train.txt -O ' + + model_io_path + ' --no-check-certificate') + else: + self.train_pids = ['21037', '22976', '2224', '14145', '27962', '26553', '22110', '16541', '26044', + '19492', '31882', '11991', '27594', '7637', '15394', '7666', '5177', '33761', + '13703', '29105'] + model_io_path = 
'dizoo/tabmwp/data/model_in_out_eval.txt' + os.system(f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_eval.txt -O ' + + model_io_path + ' --no-check-certificate') + + self.cfg.cand_number = len(self.cand_pids) + self.cfg.train_number = len(self.train_pids) + + self.results_memory = [] + with open(model_io_path, encoding="ISO-8859-1") as f: + tmp = f.read().split('\n') + for tt in tmp: + if len(tt.strip()) == 0: + continue + self.results_memory.append(eval(tt)) + + self.cand_examples = [] + self.correct_num = 0 + for pid in self.cand_pids: + example = create_example_from_pid(pid, self.problems, self.cfg, test=True) + self.cand_examples.append(example) + + self._init_flag = True + self.problem_id = 0 + train_sample = create_example_from_pid(self.train_pids[self.problem_id], self.problems, self.cfg, test=True) + obs = {'train_sample': train_sample, 'candidate_samples': self.cand_examples} + return obs + + def search_answer(self, pid, pids): + for item in self.results_memory: + if item['pid'] != pid: + continue + if item['shot_pids'] == pids: + return item['output'] + + raise ValueError('item does not exists.') + + def parse_all_answers(self): + self.cand_pids = ['32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514', '19270', '23713', '17209', '33379', '34987', '11177', '30218', '26066', '24169', '28492'] + self.train_pids = ['14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433', '26979', '18135', '13347', '17679', '38426', '3454', '10432', '31011', '12162', '13063', '7812', '29661', '24482', '4970', '4405', '17405', '27781', '26724', '5993', '16442', '30148', '15895', '6855', '29903', '18107', '29504', '11106', '32964', '29891', '32104', '15712', '24287', '4997', '32581', '21020', '17247', '31455', '13245', '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198', '26039', '3791', '4909', '37056', '7144', '8185', '2131', '4398', '38199', '29520', '37329', '21388', '28659', '15044', '28510', '12903', '11794', '37095', '32229', '22918', '31680', '15024', '24607', '26930'] + self.problem_id = 0 + self.cfg.train_number = len(self.train_pids) + n = len(self.cand_pids) + + with open('sampled_pid.txt', 'w') as f: + f.write(str(self.cand_pids) + '\n') + f.write(str(self.train_pids) + '\n') + + with open('model_in_out.txt', 'w') as f: + while self.problem_id < self.cfg.train_number: + for i in range(n): + for j in range(n): + if i == j: + continue + shot_pids = [self.cand_pids[i], self.cand_pids[j]] + pid = self.train_pids[self.problem_id] + + # generate the prompt input + prompt = build_prompt(self.problems, shot_pids, pid, self.cfg) + + # get the output from LM + # assert self._args.engine == 'text-davinci-002' + output = get_gpt3_output(prompt, self.cfg) + + output_txt = {'shot_pids': shot_pids, 'pid': pid, 'prompt': prompt, 'output': output} + f.write(str(output_txt) + '\n') + print(self.problem_id, i, j) + + self.problem_id += 1 + + def close(self) -> None: + self._init_flag = False + + def step(self, action: np.array) -> BaseEnvTimestep: + shot_pids = [self.cand_pids[cid] for cid in action] + pid = self.train_pids[self.problem_id] + + # generate the prompt input + prompt = build_prompt(self.problems, shot_pids, pid, self.cfg) + + # get the output from LM + if self.enable_replay: + output = self.search_answer(pid, shot_pids) + elif self.cfg.engine == 'text-davinci-002': + output = get_gpt3_output(prompt, self.cfg) + elif self.cfg.engine == 'rwkv-7B': + output = calc_rwkv(self.model, self.tokenizer, prompt) + 
elif self.cfg.engine == 'internlm-7B': + output = calc_internlm(self.model, self.tokenizer, prompt, self.cfg) + else: + output = self.get_output(prompt) + + # extract the prediction from the output + prediction = extract_prediction(output, self.problems[pid]['choices'], self.cfg.option_inds) + + # normalize the number in the text + prediction_norm = normalize_answer(prediction, self.problems[pid]['unit']) + + if prediction_norm.lower() == normalize_answer(self.problems[pid]['answer'], + self.problems[pid]['unit']).lower(): + reward = 1 + self.correct_num += 1 + else: + reward = -1 + + self.problem_id += 1 + if self.problem_id == self.cfg.train_number: + done = True + info = {'eval_episode_return': self.correct_num / self.cfg.train_number} + else: + done = False + info = {} + + train_sample = create_example_from_pid(pid, self.problems, self.cfg, test=True) + obs = {'train_sample': train_sample, 'candidate_samples': self.cand_examples} + + return BaseEnvTimestep(obs, reward, done, info) + + def __repr__(self) -> str: + return "DI-engine tabmwp Env" + + +if __name__ == '__main__': + from easydict import EasyDict + env_cfg = EasyDict(dict( + cand_number=16, + train_number=20, + engine='text-davinci-002', + temperature=0., + max_tokens=512, + top_p=1., + frequency_penalty=0., + presence_penalty=0., + option_inds=["A", "B", "C", "D", "E", "F"], + api_key='xxx', + prompt_format='TQ-A', + enable_replay=True, + seed=0, + )) + env = TabMWP(env_cfg) + env.seed(0) + env.reset() + env.parse_all_answers() + env.search_answer('22976', ['32889', '8044']) + diff --git a/dizoo/tabmwp/envs/test_tabmwp_env.py b/dizoo/tabmwp/envs/test_tabmwp_env.py new file mode 100644 index 0000000000..ca9020d971 --- /dev/null +++ b/dizoo/tabmwp/envs/test_tabmwp_env.py @@ -0,0 +1,25 @@ +from easydict import EasyDict +import pytest +from dizoo.tabmwp.envs.tabmwp_env import TabMWP + + +@pytest.mark.envtest +class TestSokoban: + + def test_tabmwp(self): + config = dict( + cand_number=20, + train_number=100, + engine='text-davinci-002', + temperature=0., + max_tokens=512, + top_p=1., + frequency_penalty=0., + presence_penalty=0., + option_inds=["A", "B", "C", "D", "E", "F"], + api_key='', + ) + config = EasyDict(config) + env = TabMWP(config) + env.seed(0) + env.close() diff --git a/dizoo/tabmwp/envs/utils.py b/dizoo/tabmwp/envs/utils.py new file mode 100644 index 0000000000..f1f74a3f0c --- /dev/null +++ b/dizoo/tabmwp/envs/utils.py @@ -0,0 +1,335 @@ +import json +import os +import random +import re +import time +from functools import lru_cache +import torch + +import numpy as np +import openai +try: + import transformers +except ImportError: + import sys + from ditk import logging + logging.warning("not found transformer, please install it using: pip install transformers") + sys.exit(1) + + +def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8) -> int: + # Sample an action given the logits. + probs = torch.softmax(out, dim=-1).cpu().numpy() + sorted_probs = np.sort(probs)[::-1] + cumulative_probs = np.cumsum(sorted_probs) + cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) + probs[probs < cutoff] = 0 + if temperature != 1.0: + probs = probs.pow(1.0 / temperature) + probs = probs / np.sum(probs) + out = np.random.choice(a=len(probs), p=probs) + return out + + +def calc_rwkv(model: transformers.RwkvForCausalLM, tokenizer: transformers.AutoTokenizer, prompt: str, max_len: int = 10) -> str: + # Use RWKV to generate sentence. 
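+    # The prompt is encoded and passed through the model once to get the logits at the last position.
+    # Then, for up to max_len steps, sample_logits draws the next token (top-p filtering with optional
+    # temperature), the decoded token is appended to the prompt, and the model is re-run on the
+    # extended prompt. Only the newly generated suffix (prompt[orig_len:]) is returned.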
+ orig_len = len(prompt) + inputs = tokenizer(prompt, return_tensors="pt").to('cuda') + outputs = model(**inputs, labels=inputs["input_ids"]) + out, state = outputs.logits, outputs.state + # Recurrent generation. + with torch.no_grad(): + for i in range(max_len): + token = sample_logits(out[0, -1]) + tmp = tokenizer.decode([token]) + prompt = prompt + tmp + inputs = tokenizer(prompt, return_tensors="pt").to('cuda') + outputs = model(**inputs, labels=inputs["input_ids"]) + out, state = outputs.logits, outputs.state + return prompt[orig_len:] + + +def calc_internlm(model, tokenizer, prompt: str, args): + inputs = tokenizer(prompt, return_tensors="pt") + for k, v in inputs.items(): + inputs[k] = v.cuda() + gen_kwargs = {"max_length": args.max_tokens, "top_p": args.top_p, "temperature": args.temperature, "do_sample": True, + "repetition_penalty": args.frequency_penalty} + output = model.generate(**inputs, **gen_kwargs) + output = tokenizer.decode(output) + return output + + +def load_data(args: dict) -> tuple: + # Load tabmwp dataset. + random.seed(args.seed) + data_root = 'dizoo/tabmwp/data' + + if not os.path.exists(data_root): + os.mkdir(data_root) + + if not os.path.exists(os.path.join(data_root, f'problems_train.json')): + os.system(f'wget https://opendilab.net/download/DI-zoo/tabmwp/problems_train.json -O ' + + os.path.join(data_root, f'problems_train.json') + ' --no-check-certificate') + problems = json.load(open(os.path.join(data_root, f'problems_train.json'))) + + pids = list(problems.keys()) + samples = random.sample(pids, args.train_number + args.cand_number) # random sample + train_pids = samples[:args.train_number] + cand_pids = samples[args.train_number:] + return problems, cand_pids, train_pids + + +def get_gpt3_output(prompt: str, args: dict) -> str: + return call_gpt3(args.engine, prompt, args.temperature, args.max_tokens, args.top_p, args.frequency_penalty, + args.presence_penalty) + + +@lru_cache(maxsize=10000) +def call_gpt3(engine: str, prompt: str, temperature: float, max_tokens: int, top_p: float, + frequency_penalty: float, presence_penalty: float) -> str: + patience = 100 + while True: + try: + response = openai.Completion.create(engine=engine, + prompt=prompt, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + stop=["\n"]) + output = response["choices"][0]["text"].strip() + break + except Exception: + patience -= 1 + if not patience: + print("!!! 
running out of patience waiting for OpenAI") + else: + time.sleep(0.1) + return output + + +def get_table_text(problem: dict) -> str: + table = problem['table'] + title = problem['table_title'] + if title and len(title) > 0: + table = f"[TITLE]: {title}\n{table}" + return table + + +def get_question_text(problem: dict, option_inds: list) -> str: + question = problem['question'] + + unit = problem['unit'] + if unit and len(unit) > 0: + question = f"{question} (Unit: {unit})" + + choices = problem['choices'] + if choices and len(choices) > 0: + choice_list = [] + for i, c in enumerate(choices): + choice_list.append("({}) {}".format(option_inds[i], c)) + options = " ".join(choice_list) + question = f"{question}\nOptions: {options}" + + return question + + +def get_answer(problem: dict) -> str: + return problem['answer'] + + +def get_solution_text(problem: dict) -> str: + # GPT-3 can generate the solution with more tokens + solution = problem['solution'].replace("\n", "\\n") + return solution + + +def create_one_example(format: str, table: str, question: str, answer: str, + solution: str, test_example:bool = True) -> str: + # Using template to generate one prompt example. + input_format, output_format = format.split("-") # e.g., "TQ-A" + + elements = { + "Q": f"Question: {question}", + "T": f"Table: {table}", + "S": f"Solution: {solution}", + "A": f"Answer: The answer is {answer}.", + "AS": f"Answer: The answer is {answer}. BECAUSE: {solution}", + "SA": f"Answer: {solution} The answer is {answer}." + } + + # Input + input = "\n".join(elements[label] for label in input_format) + + # Output + if test_example: + output = "Answer:" + else: + output = elements[output_format] + + # Prompt text + text = input + "\n" + output + text = text.replace(" ", " ").strip() + + return text + + +def build_prompt(problems: list, shot_pids: list, test_pid: int, args: dict) -> str: + # Given ids, generate the complete prompt. That is, the input to LM. + examples = [] + pids = shot_pids + [test_pid] + + # n-shot training examples + for pid in pids: + problem = problems[pid] + table = get_table_text(problem) + question = get_question_text(problem, args.option_inds) + answer = get_answer(problem) + solution = get_solution_text(problems[pid]) + + if pid == test_pid: + assert pid not in shot_pids + example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=True) + else: + example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=False) + + examples.append(example) + + # create the prompt input + prompt_input = '\n\n'.join(examples) + + return prompt_input + + +def extract_prediction(output: str, options: list, option_inds: list) -> str: + idx = output.find('\n') + if idx > 0: + output = output[:idx] + idx = output.find('=') + if idx > 0: + output = output[idx + 1:].strip() + # $\\frac{16}{95}$ -> 16/95 + output = re.sub(r"\$?\\frac\{([\d\.\,\-]+)\}\{([\d\.\,]+)\}\$?", r"\1/\2", output) + + output = re.sub(r"(? 
0: + pred = res[0].upper() # e.g., "B" + if pred in option_inds: + ind = option_inds.index(pred) # 1 + if ind >= len(options): + ind = random.choice(range(len(options))) + predition = options[ind] + return predition + + # find the most similar options + scores = [score_string_similarity(x, output) for x in options] + max_idx = int(np.argmax(scores)) # json does not recognize NumPy data types + predition = options[max_idx] + return predition + + else: + # free_text QA problems, numeric answer + patterns = [ + # r'^\([A-Za-z]\) ([\s\S]+)$', # "(A) XXXXX" + # r'[Th]he answer is \([A-Za-z]\) ([\s\S]+)$', # "The answer is (B) XXXXX." + r'[Th]he answer is ([\s\S]+)$', # "The answer is XXXXX.", + r'[Th]he table shows that ([\d\$\.\,\/\:]+) ', + r' = ([\d\$\.\,\/\:]+)', # "= $1.40" + r'(?<= be| is) ([\-\d\$\.\,\/\:]{0,}[\d]+)', # "will be $1.40" + r'(?<= are| was) ([\-\d\$\.\,\/\:]{0,}[\d]+)', # "are $1.40" + r'(?<= were) ([\-\d\$\.\,\/\:]{0,}[\d]+)', # "are $1.40" + r' ([\d\$\.\,\/\:]+ [AP]\.M\.)', # 7:25 P.M. + r'([\-\d\$\.\,\/\:]{0,}[\d]+)', # 14.5 + ] + + for p in patterns: + pattern = re.compile(p) + res = pattern.findall(output) + if len(res) > 0: + predition = res[-1].strip() + if predition.endswith(".") and ".M." not in predition: + predition = predition[:-1] + return predition + + return output + + +def normalize_answer(text: str, unit: str) -> str: + # ["1,000", "123", "3/4", "56.456", "$56.4", "-3", "-10.02", "-3/2"] + + text = re.sub("^[\$]", "", text) + text = re.sub("[\,\.\,\/]$", "", text) + result = re.match("^[-+]?[\d,./]+$", text) + + if result is not None: + # is number? + text = text.replace(",", "") + result = re.match("[-+]?\d+$", text) + try: + if result is not None: + number = int(text) + elif "/" in text: + nums = text.split("/") + number = round(float(nums[0]) / float(nums[1]), 3) + else: + number = round(float(text), 3) + number = str(number) + number = re.sub(r"\.[0]+$", "", number) + return number + except: + return text + else: + # is text + if unit: + text = text.replace(unit, "").strip() + return text + + +def score_string_similarity(str1: str, str2: str) -> float: + if str1 == str2: + return 2.0 + if " " in str1 or " " in str2: + str1_split = str1.split(" ") + str2_split = str2.split(" ") + overlap = list(set(str1_split) & set(str2_split)) + return len(overlap) / max(len(str1_split), len(str2_split)) + else: + if str1 == str2: + return 1.0 + else: + return 0.0 + + +def create_example_from_pid(pid: int, problems: list, args: dict, test: bool = False) -> str: + problem = problems[pid] + table = get_table_text(problem) + question = get_question_text(problem, args.option_inds) + answer = get_answer(problem) + solution = get_solution_text(problems[pid]) + + if test: + example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=True) + else: + example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=False) + + return example
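For reference, the following minimal Python sketch mirrors what the `'TQ-A'` prompt format produces for a test example via `create_one_example`; the table and question strings are hypothetical placeholders, not taken from the dataset:

    # Illustrative only: mirrors the 'TQ-A' branch of create_one_example with test_example=True.
    table = "[TITLE]: Pizza sales\nDay | Pizzas sold\nMonday | 45\nTuesday | 62"  # hypothetical table text
    question = "How many more pizzas were sold on Tuesday than on Monday?"  # hypothetical question text
    elements = {"T": f"Table: {table}", "Q": f"Question: {question}"}
    input_part = "\n".join(elements[label] for label in "TQ")  # the 'TQ' half of 'TQ-A'
    prompt = input_part + "\n" + "Answer:"  # test examples end with a bare "Answer:" for the LM to complete
    print(prompt)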
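Similarly, a few worked cases for `normalize_answer`, traced by hand from the regex logic above (illustrative asserts, not part of the patch):

    # Illustrative checks of the answer normalization; run from the repo root so the import resolves.
    from dizoo.tabmwp.envs.utils import normalize_answer

    assert normalize_answer("$1,200.50", None) == "1200.5"   # leading '$' and thousands commas are removed
    assert normalize_answer("3/4", None) == "0.75"           # fractions are evaluated and rounded to 3 places
    assert normalize_answer("72.0", None) == "72"            # a trailing '.0' is stripped from whole numbers
    assert normalize_answer("5 dollars", "dollars") == "5"   # non-numeric text only has its unit removed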