Feature/command line (#131)
* add command line scripts

* add plot command line script

* add examples folder

* add watch and release scripts to setup.py

* update getting started

* update readme
cpnota committed Apr 17, 2020
1 parent ca5ec2b commit f937068
Showing 15 changed files with 127 additions and 74 deletions.
10 changes: 7 additions & 3 deletions README.md
@@ -68,10 +68,10 @@ pip install autonomous-learning-library[pytorch]

## Running the Presets

If you just want to test out some cool agents, the `scripts` directory contains the basic code for doing so.
If you just want to test out some cool agents, the library includes several scripts for doing so:

```
python scripts/atari.py Breakout a2c
all-atari Breakout a2c
```

You can watch the training progress using:
@@ -84,12 +84,16 @@ and opening your browser to http://localhost:6006.
Once the model is trained to your satisfaction, you can watch the trained model play using:

```
python scripts/watch_atari.py Breakout "runs/_a2c [id]"
all-watch-atari Breakout "runs/_a2c [id]"
```

where `id` is the ID of your particular run. You should be able to find it using tab completion or by looking in the `runs` directory.
The `autonomous-learning-library` also contains presets and scripts for classic control and PyBullet environments.

If you want to test out your own agents, you will need to define your own scripts.
Some examples can be found in the `examples` folder.
See the [docs](https://autonomous-learning-library.readthedocs.io) for information on building your own agents!
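
For a quick taste, a custom script can be as small as the following sketch, which mirrors the `run_experiment` example added under `examples/` in this commit (the particular presets and environment are just illustrative choices):

```
from all.experiments import run_experiment
from all.presets.classic_control import a2c, dqn
from all.environments import GymEnvironment

# train each preset agent on CartPole for 40,000 timesteps;
# results are written under the runs/ directory
run_experiment(
    [a2c(), dqn()],
    [GymEnvironment('CartPole-v0', 'cpu')],
    40000,
)
```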

## Note

This library was built in the [Autonomous Learning Laboratory](http://all.cs.umass.edu) (ALL) at the [University of Massachusetts, Amherst](https://www.umass.edu).
31 changes: 22 additions & 9 deletions docs/source/guide/getting_started.rst
@@ -26,11 +26,15 @@ If you don't have PyTorch or Tensorboard previously installed, you can install them
An alternative approach, which may be useful when following this tutorial, is to install by cloning the GitHub repository:

.. code-block:: bash

   git clone https://github.com/cpnota/autonomous-learning-library.git
   cd autonomous-learning-library
   pip install -e .
   pip install -e .["dev"]

If you chose to clone the repository, you can test your installation by running the unit test suite:

You can test your installation by running the tests::
.. code-block:: bash

   make test

@@ -42,22 +46,28 @@ Running a Preset Agent
The goal of the Autonomous Learning Library is to provide components for building new agents.
However, the library also includes a number of "preset" agent configurations for easy benchmarking and comparison,
as well as some useful scripts.
For example, an A2C agent can be run on Cart-Pole as follows::
For example, an A2C agent can be run on Cart-Pole as follows:

.. code-block:: bash

   python scripts/classic.py CartPole-v0 ppo
   all-classic CartPole-v0 a2c

The results will be written to ``runs/_a2c <id>``, where ``<id>`` is a string generated by the library.
You can view these results and other information through `tensorboard`:

.. code-block:: bash

   tensorboard --logdir runs

By opening your browser to <http://localhost:6006>, you should see a dashboard that looks something like the following (you may need to adjust the "smoothing" parameter):

.. image:: tensorboard.png

If you want to compare agents in a nicer format, you can use the `plot` script::
If you want to compare agents in a nicer format, you can use the `plot` script:

.. code-block:: bash

   python scripts/plot.py runs
   all-plot --logdir runs

This should give you a plot similar to the following:

@@ -66,10 +76,13 @@ This should give you a plot similar to the following:
In this plot, each point represents the average of the episodic returns over the last 100 episodes.
The shaded region represents the standard deviation over that interval.
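
As a rough illustration of the statistic behind each point (a hand-rolled sketch, not the library's internal code):

.. code-block:: python

   import numpy as np

   # toy data standing in for one run's episodic returns
   returns = np.random.normal(loc=100.0, scale=20.0, size=500)
   window = returns[-100:]   # the most recent 100 episodes
   point = window.mean()     # the value plotted at this step
   spread = window.std()     # half-width of the shaded region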

Finally, to watch the trained model in action, we provide a `watch` script for each preset module::
Finally, to watch the trained model in action, we provide a `watch` script for each preset module:

.. code-block:: bash

   python scripts/watch_classic.py CartPole-v0 "runs/_a2c <id>"
   all-watch-classic CartPole-v0 "runs/_a2c <id>"

You will need to find the ``<id>`` by checking the ``runs`` directory.

Be sure to check out the `atari` and `continuous` presets for more fun!
Each of these scripts can be found in the ``scripts`` directory of the main repository.
Be sure to check out the ``atari`` and ``continuous`` scripts for more fun!
Empty file added examples/__init__.py
Empty file.
19 changes: 19 additions & 0 deletions examples/experiment.py
@@ -0,0 +1,19 @@
'''
Quick example of usage of the run_experiment API.
'''
from all.experiments import run_experiment, plot_returns_100
from all.presets.classic_control import dqn, a2c
from all.environments import GymEnvironment

def main():
    device = 'cpu'
    timesteps = 40000
    # run each preset agent on each environment
    run_experiment(
        [dqn(), a2c()],
        [GymEnvironment('CartPole-v0', device), GymEnvironment('Acrobot-v1', device)],
        timesteps,
    )
    # plot the average return over the last 100 episodes of each run
    plot_returns_100('runs', timesteps=timesteps)

if __name__ == "__main__":
    main()
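
Assuming the package is installed (for example with `pip install -e .`), this example can be run directly with `python examples/experiment.py`; `run_experiment` writes its results under `runs/`, which `plot_returns_100` then reads.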
18 changes: 18 additions & 0 deletions examples/slurm_experiment.py
@@ -0,0 +1,18 @@
'''
Quick example of running a2c on Slurm, a distributed cluster scheduler.
Note that it only runs for 1 million frames.
For real experiments, you will surely need a modified version of this script.
'''
from all.experiments import SlurmExperiment
from all.presets.atari import a2c
from all.environments import AtariEnvironment

def main():
    device = 'cuda'
    envs = [AtariEnvironment(env, device) for env in ['Pong', 'Breakout', 'SpaceInvaders']]
    SlurmExperiment(a2c(device=device), envs, 1e6, sbatch_args={
        'partition': '1080ti-short'
    })

if __name__ == "__main__":
    main()
4 changes: 2 additions & 2 deletions scripts/atari.py
@@ -3,7 +3,7 @@
from all.experiments import run_experiment
from all.presets import atari

def run_atari():
def main():
    parser = argparse.ArgumentParser(description="Run an Atari benchmark.")
    parser.add_argument("env", help="Name of the Atari game (e.g. Pong).")
    parser.add_argument(
@@ -30,4 +30,4 @@ def run_atari():


if __name__ == "__main__":
    run_atari()
    main()
4 changes: 2 additions & 2 deletions scripts/classic.py
@@ -4,7 +4,7 @@
from all.presets import classic_control


def run_classic():
def main():
    parser = argparse.ArgumentParser(description="Run a classic control benchmark.")
    parser.add_argument("env", help="Name of the env (e.g. CartPole-v1).")
    parser.add_argument(
@@ -31,4 +31,4 @@ def run_classic():


if __name__ == "__main__":
    run_classic()
    main()
4 changes: 2 additions & 2 deletions scripts/continuous.py
@@ -21,7 +21,7 @@
}


def run():
def main():
    parser = argparse.ArgumentParser(description="Run a continuous actions benchmark.")
    parser.add_argument("env", help="Name of the env (see envs)")
    parser.add_argument(
@@ -53,4 +53,4 @@ def run():


if __name__ == "__main__":
    run()
    main()
8 changes: 4 additions & 4 deletions scripts/plot.py
@@ -2,12 +2,12 @@
from all.experiments import plot_returns_100


def plot():
def main():
    parser = argparse.ArgumentParser(description="Plots the results of experiments.")
    parser.add_argument("dir", help="Output directory.")
    parser.add_argument("--logdir", help="Output directory", default='runs')
    parser.add_argument("--timesteps", type=int, default=-1, help="The final point will be fixed to this x-value")
    args = parser.parse_args()
    plot_returns_100(args.dir, timesteps=args.timesteps)
    plot_returns_100(args.logdir, timesteps=args.timesteps)

if __name__ == "__main__":
    plot()
    main()
65 changes: 34 additions & 31 deletions scripts/release.py
@@ -1,39 +1,42 @@
'''Create slurm tasks to run benchmark suite'''
import argparse
'''Create slurm tasks to run release test suite'''
from all.environments import AtariEnvironment, GymEnvironment
from all.experiments import SlurmExperiment
from all.presets import atari, classic_control, continuous

# run on gpu
device = 'cuda'
def main():
    # run on gpu
    device = 'cuda'

def get_agents(preset):
    agents = [getattr(preset, agent_name) for agent_name in classic_control.__all__]
    return [agent(device=device) for agent in agents]
    def get_agents(preset):
        # build one configured agent per preset listed in the given module
        agents = [getattr(preset, agent_name) for agent_name in preset.__all__]
        return [agent(device=device) for agent in agents]

SlurmExperiment(
    get_agents(atari),
    AtariEnvironment('Breakout', device=device),
    2e7,
    sbatch_args={
        'partition': '1080ti-long'
    }
)
    SlurmExperiment(
        get_agents(atari),
        AtariEnvironment('Breakout', device=device),
        2e7,
        sbatch_args={
            'partition': '1080ti-long'
        }
    )

SlurmExperiment(
    get_agents(classic_control),
    GymEnvironment('CartPole-v0', device=device),
    100000,
    sbatch_args={
        'partition': '1080ti-short'
    }
)
    SlurmExperiment(
        get_agents(classic_control),
        GymEnvironment('CartPole-v0', device=device),
        100000,
        sbatch_args={
            'partition': '1080ti-short'
        }
    )

SlurmExperiment(
    get_agents(continuous),
    GymEnvironment('LunarLanderContinuous-v2', device=device),
    500000,
    sbatch_args={
        'partition': '1080ti-short'
    }
)
    SlurmExperiment(
        get_agents(continuous),
        GymEnvironment('LunarLanderContinuous-v2', device=device),
        500000,
        sbatch_args={
            'partition': '1080ti-short'
        }
    )

if __name__ == "__main__":
    main()
15 changes: 0 additions & 15 deletions scripts/slurm_atari.py

This file was deleted.

4 changes: 2 additions & 2 deletions scripts/watch_atari.py
@@ -4,7 +4,7 @@
from all.experiments import GreedyAgent, watch


def watch_atari():
def main():
    parser = argparse.ArgumentParser(description="Watch an Atari agent.")
    parser.add_argument("env", help="Name of the Atari game (e.g. Pong)")
    parser.add_argument("dir", help="Directory where the agent's model was saved.")
@@ -24,4 +24,4 @@ def watch_atari():
    watch(agent, env, fps=args.fps)

if __name__ == "__main__":
    watch_atari()
    main()
4 changes: 2 additions & 2 deletions scripts/watch_classic.py
@@ -2,7 +2,7 @@
from all.environments import GymEnvironment
from all.experiments import load_and_watch

def watch():
def main():
    parser = argparse.ArgumentParser(description="Watch a classic control agent.")
    parser.add_argument("env", help="Name of the environment (e.g. RoboschoolHalfCheetah-v1)")
    parser.add_argument("dir", help="Directory where the agent's model was saved.")
@@ -16,4 +16,4 @@ def watch():
    load_and_watch(args.dir, env)

if __name__ == "__main__":
    watch()
    main()
4 changes: 2 additions & 2 deletions scripts/watch_continuous.py
@@ -8,7 +8,7 @@
from continuous import ENVS


def watch_continuous():
def main():
    parser = argparse.ArgumentParser(description="Watch a continuous agent.")
    parser.add_argument("env", help="ID of the Environment")
    parser.add_argument("dir", help="Directory where the agent's model was saved.")
@@ -34,4 +34,4 @@ def watch_continuous():
    watch(agent, env, fps=args.fps)

if __name__ == "__main__":
    watch_continuous()
    main()
11 changes: 11 additions & 0 deletions setup.py
@@ -8,6 +8,17 @@
    url="https://github.com/cpnota/autonomous-learning-library.git",
    author="Chris Nota",
    author_email="cnota@cs.umass.edu",
    entry_points={
        'console_scripts': [
            'all-atari=scripts.atari:main',
            'all-classic=scripts.classic:main',
            'all-continuous=scripts.continuous:main',
            'all-plot=scripts.plot:main',
            'all-watch-atari=scripts.watch_atari:main',
            'all-watch-classic=scripts.watch_classic:main',
            'all-watch-continuous=scripts.watch_continuous:main',
        ],
    },
    install_requires=[
        "gym[atari,box2d]",  # common environments
        "numpy",  # math library
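
For context, each `console_scripts` entry such as `all-plot=scripts.plot:main` tells setuptools to generate a launcher at install time that is roughly equivalent to this simplified sketch:

```
# simplified stand-in for the generated 'all-plot' command:
# import the target module and call the named function
import sys
from scripts.plot import main

if __name__ == '__main__':
    sys.exit(main())
```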
