Skip to content

Commit

Permalink
Feature/slurm experiment (#190)
Browse files (browse the repository at this point in the history)
* make slurm outdir and rundir configurable

* fix mkdir

* fix mkdir for real this time

* fix script name

* tweak runtime on release script

* remove internal presets from exports

* explicitly set device in benchmark scripts
  • Loading branch information
cpnota committed Dec 31, 2020
1 parent fb28f66 commit 8254c6f
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 45 deletions.
35 changes: 24 additions & 11 deletions all/experiments/slurm.py
Original file line number Diff line number Diff line change
@@ -4,8 +4,6 @@
import sys
from .run_experiment import run_experiment

SCRIPT_NAME = 'experiment.sh'
OUT_DIR = 'out'

# track the number of experiments created
# in the current process
@@ -21,7 +19,11 @@ def __init__(
envs,
frames,
test_episodes=100,
write_loss=False,
job_name='autonomous-learning-library',
script_name='experiment.sh',
outdir='out',
logdir='runs',
sbatch_args=None,
):
if not isinstance(agents, list):
@@ -34,7 +36,11 @@ def __init__(
self.envs = envs
self.frames = frames
self.test_episodes = test_episodes
self.write_loss = write_loss
self.job_name = job_name
self.script_name = script_name
self.outdir = outdir
self.logdir = logdir
self.sbatch_args = sbatch_args or {}
self.parse_args()

@@ -61,22 +67,29 @@ def run_experiment(self):
task_id = int(os.environ['SLURM_ARRAY_TASK_ID'])
env = self.envs[int(task_id / len(self.agents))]
agent = self.agents[task_id % len(self.agents)]
run_experiment(agent, env, self.frames, test_episodes=self.test_episodes, write_loss=False)
run_experiment(
agent,
env,
self.frames,
test_episodes=self.test_episodes,
logdir=self.logdir,
write_loss=self.write_loss
)

def queue_jobs(self):
self.create_sbatch_script()
self.make_output_directory()
self.run_sbatch_script()

def create_sbatch_script(self):
script = open(SCRIPT_NAME, 'w')
script = open(self.script_name, 'w')
script.write('#!/bin/sh\n\n')
num_experiments = len(self.envs) * len(self.agents)

sbatch_args = {
'job-name': self.job_name,
'output': 'out/all_%A_%a.out',
'error': 'out/all_%A_%a.err',
'output': os.path.join(self.outdir, 'all_%A_%a.out'),
'error': os.path.join(self.outdir, 'all_%A_%a.err'),
'array': '0-' + str(num_experiments - 1),
'partition': '1080ti-short',
'ntasks': 1,
@@ -91,18 +104,18 @@ def create_sbatch_script(self):

script.write('python ' + sys.argv[0] + ' --experiment_id ' + str(self._id) + '\n')
script.close()
print('created sbatch script:', SCRIPT_NAME)
print('created sbatch script:', self.script_name)

def make_output_directory(self):
try:
os.mkdir(OUT_DIR)
print('Created output directory:', OUT_DIR)
os.mkdir(self.outdir)
print('Created output directory:', self.outdir)
except FileExistsError:
print('Output directory already exists:', OUT_DIR)
print('Output directory already exists:', self.outdir)

def run_sbatch_script(self):
result = subprocess.run(
['sbatch', SCRIPT_NAME],
['sbatch', self.script_name],
stdout=subprocess.PIPE,
check=True
)
Expand Down
10 changes: 0 additions & 10 deletions all/presets/atari/__init__.py
Original file line number Diff line number Diff line change
@@ -12,23 +12,13 @@

__all__ = [
"a2c",
"A2CAtariPreset",
"c51",
"C51AtariPreset",
"ddqn",
"DDQNAtariPreset",
"dqn",
"DQNAtariPreset",
"ppo",
"PPOAtariPreset",
"rainbow",
"RainbowAtariPreset",
"vac",
"VACAtariPreset",
"vpg",
"VPGAtariPreset",
"vqn",
"VQNAtariPreset",
"vsarsa",
"VSarsaAtariPreset",
]
10 changes: 0 additions & 10 deletions all/presets/classic_control/__init__.py
Original file line number Diff line number Diff line change
@@ -11,23 +11,13 @@

__all__ = [
"a2c",
"A2CClassicControlPreset",
"c51",
"C51ClassicControlPreset",
"ddqn",
"DDQNClassicControlPreset",
"dqn",
"DQNClassicControlPreset",
"ppo",
"PPOClassicControlPreset",
"rainbow",
"RainbowClassicControlPreset",
"vac",
"VACClassicControlPreset",
"vpg",
"VPGClassicControlPreset",
"vqn",
"VQNClassicControlPreset",
"vsarsa",
"VSarsaClassicControlPreset",
]
2 changes: 0 additions & 2 deletions all/presets/continuous/__init__.py
Original file line number Diff line number Diff line change
@@ -5,8 +5,6 @@

__all__ = [
'ddpg',
'DDPGContinuousPreset',
'ppo',
'PPOContinuousPreset',
'sac',
]
15 changes: 8 additions & 7 deletions benchmarks/atari40.py
Original file line number Diff line number Diff line change
@@ -4,15 +4,16 @@


def main():
device = 'cuda'
agents = [
atari.a2c(),
atari.c51(),
atari.dqn(),
atari.ddqn(),
atari.ppo(),
atari.rainbow(),
atari.a2c(device=device),
atari.c51(device=device),
atari.dqn(device=device),
atari.ddqn(device=device),
atari.ppo(device=device),
atari.rainbow(device=device),
]
envs = [AtariEnvironment(env, device='cuda') for env in ['BeamRider', 'Breakout', 'Pong', 'Qbert', 'SpaceInvaders']]
envs = [AtariEnvironment(env, device=device) for env in ['BeamRider', 'Breakout', 'Pong', 'Qbert', 'SpaceInvaders']]
SlurmExperiment(agents, envs, 10e6, sbatch_args={
'partition': '1080ti-long'
})
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/pybullet.py
Original file line number Diff line number Diff line change
@@ -11,9 +11,9 @@ def main():
frames = int(1e7)

agents = [
ddpg(last_frame=frames),
ppo(last_frame=frames),
sac(last_frame=frames)
ddpg(device=device),
ppo(device=device),
sac(device=device)
]

envs = [GymEnvironment(env, device) for env in [
Expand Down
4 changes: 2 additions & 2 deletions scripts/release.py
Original file line number Diff line number Diff line change
@@ -9,13 +9,13 @@ def main():
device = 'cuda'

def get_agents(preset):
agents = [getattr(preset, agent_name) for agent_name in classic_control.__all__]
agents = [getattr(preset, agent_name) for agent_name in preset.__all__]
return [agent(device=device) for agent in agents]

SlurmExperiment(
get_agents(atari),
AtariEnvironment('Breakout', device=device),
2e7,
10e7,
sbatch_args={
'partition': '1080ti-long'
}
Expand Down

0 comments on commit 8254c6f

Please sign in to comment.