Skip to content

Commit

Permalink
Merge branch 'master' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
cpnota committed Apr 18, 2020
2 parents e3c93dc + 57536b2 commit 781fa2f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
20 changes: 20 additions & 0 deletions benchmarks/atari40.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from all.experiments import SlurmExperiment
from all.presets import atari
from all.environments import AtariEnvironment

def main():
agents = [
atari.a2c(),
atari.c51(),
atari.dqn(),
atari.ddqn(),
atari.ppo(),
atari.rainbow(),
]
envs = [AtariEnvironment(env, device='cuda') for env in ['BeamRider', 'Breakout', 'Pong', 'Qbert', 'SpaceInvaders']]
SlurmExperiment(agents, envs, 10e6, sbatch_args={
'partition': '1080ti-long'
})

if __name__ == "__main__":
main()
31 changes: 31 additions & 0 deletions benchmarks/pybullet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pybullet
import pybullet_envs
from all.experiments import SlurmExperiment
from all.presets.continuous import ddpg, ppo, sac
from all.environments import GymEnvironment

def main():
device = 'cuda'

frames = int(1e7)

agents = [
ddpg(last_frame=frames),
ppo(last_frame=frames),
sac(last_frame=frames)
]

envs = [GymEnvironment(env, device) for env in [
'AntBulletEnv-v0',
"HalfCheetahBulletEnv-v0",
'HumanoidBulletEnv-v0',
'HopperBulletEnv-v0',
'Walker2DBulletEnv-v0'
]]

SlurmExperiment(agents, envs, frames, sbatch_args={
'partition': '1080ti-long'
})

if __name__ == "__main__":
main()
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="autonomous-learning-library",
version="0.4.0",
version="0.5.0",
description=("A library for building reinforcement learning agents in Pytorch"),
packages=find_packages(),
url="https://github.com/cpnota/autonomous-learning-library.git",
Expand All @@ -17,6 +17,8 @@
'all-watch-atari=scripts.watch_atari:main',
'all-watch-classic=scripts.watch_classic:main',
'all-watch-continuous=scripts.watch_continuous:main',
'all-benchmark-atari=benchmarks.atari40:main',
'all-benchmark-pybullet=benchmarks.pybullet:main',
],
},
install_requires=[
Expand Down

0 comments on commit 781fa2f

Please sign in to comment.