redo black
jkterry1 committed Jul 29, 2021
1 parent d5004b7 commit e9d2c41
Showing 109 changed files with 459 additions and 1,363 deletions.
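The commit message "redo black" is consistent with re-running the black formatter at a line length longer than its default of 88, which is why most hunks below replace a wrapped call with a single longer line; the exact line-length setting is not visible in this commit. As a rough illustration only (127 is an assumed value, not taken from the repository configuration):

    import black

    wrapped = (
        "parser = argparse.ArgumentParser(\n"
        '    description="Renders a Gym environment for quick inspection."\n'
        ")\n"
    )
    # With a wider line limit, black joins the wrapped call back onto one line,
    # matching the shape of the added lines in this commit.
    print(black.format_str(wrapped, mode=black.Mode(line_length=127)))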
4 changes: 1 addition & 3 deletions bin/render.py
@@ -3,9 +3,7 @@
 import gym


-parser = argparse.ArgumentParser(
-    description="Renders a Gym environment for quick inspection."
-)
+parser = argparse.ArgumentParser(description="Renders a Gym environment for quick inspection.")
 parser.add_argument(
     "env_id",
     type=str,
11 changes: 2 additions & 9 deletions examples/agents/cem.py
@@ -35,12 +35,7 @@ def cem(f, th_mean, batch_size, n_iter, elite_frac, initial_std=1.0):
     th_std = np.ones_like(th_mean) * initial_std

     for _ in range(n_iter):
-        ths = np.array(
-            [
-                th_mean + dth
-                for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)
-            ]
-        )
+        ths = np.array([th_mean + dth for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)])
         ys = np.array([f(th) for th in ths])
         elite_inds = ys.argsort()[::-1][:n_elite]
         elite_ths = ths[elite_inds]
@@ -101,9 +96,7 @@ def noisy_evaluation(theta):
         return rew

     # Train the agent, and snapshot each stage
-    for (i, iterdata) in enumerate(
-        cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)
-    ):
+    for (i, iterdata) in enumerate(cem(noisy_evaluation, np.zeros(env.observation_space.shape[0] + 1), **params)):
         print("Iteration %2i. Episode mean reward: %7.3f" % (i, iterdata["y_mean"]))
         agent = BinaryActionLinearPolicy(iterdata["theta_mean"])
         if args.display:
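For context on the fragments above: cem() is the cross-entropy method, which samples a Gaussian-perturbed batch of parameter vectors around the current mean, keeps the top elite_frac of them by score, and refits the mean and std to those elites. A condensed sketch of that loop, paraphrased from the visible lines rather than quoted from the file:

    import numpy as np

    def cem_sketch(f, th_mean, batch_size, n_iter, elite_frac, initial_std=1.0):
        n_elite = int(np.round(batch_size * elite_frac))
        th_std = np.ones_like(th_mean) * initial_std
        for _ in range(n_iter):
            # Sample a batch of parameter vectors around the current mean
            # (the comprehension collapsed by this commit).
            ths = np.array([th_mean + dth for dth in th_std[None, :] * np.random.randn(batch_size, th_mean.size)])
            ys = np.array([f(th) for th in ths])
            # Keep the best-scoring elite_frac samples and refit the sampling distribution.
            elite_ths = ths[ys.argsort()[::-1][:n_elite]]
            th_mean, th_std = elite_ths.mean(axis=0), elite_ths.std(axis=0)
            yield {"theta_mean": th_mean, "y_mean": ys.mean()}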
4 changes: 1 addition & 3 deletions examples/agents/random_agent.py
@@ -17,9 +17,7 @@ def act(self, observation, reward, done):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description=None)
-    parser.add_argument(
-        "env_id", nargs="?", default="CartPole-v0", help="Select the environment to run"
-    )
+    parser.add_argument("env_id", nargs="?", default="CartPole-v0", help="Select the environment to run")
     args = parser.parse_args()

     # You can set the level to logger.DEBUG or logger.WARN if you
14 changes: 3 additions & 11 deletions gym/core.py
@@ -173,16 +173,10 @@ class GoalEnv(Env):
     def reset(self):
         # Enforce that each GoalEnv uses a Goal-compatible observation space.
         if not isinstance(self.observation_space, gym.spaces.Dict):
-            raise error.Error(
-                "GoalEnv requires an observation space of type gym.spaces.Dict"
-            )
+            raise error.Error("GoalEnv requires an observation space of type gym.spaces.Dict")
         for key in ["observation", "achieved_goal", "desired_goal"]:
             if key not in self.observation_space.spaces:
-                raise error.Error(
-                    'GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(
-                        key
-                    )
-                )
+                raise error.Error('GoalEnv requires the "{}" key to be part of the observation dictionary.'.format(key))

     def compute_reward(self, achieved_goal, desired_goal, info):
         """Compute the step reward. This externalizes the reward function and makes
@@ -227,9 +221,7 @@ def __init__(self, env):

     def __getattr__(self, name):
         if name.startswith("_"):
-            raise AttributeError(
-                "attempted to get missing private attribute '{}'".format(name)
-            )
+            raise AttributeError("attempted to get missing private attribute '{}'".format(name))
         return getattr(self.env, name)

     @property
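The reset() check above requires the observation space to be a gym.spaces.Dict containing the keys "observation", "achieved_goal", and "desired_goal". A minimal sketch of a space that satisfies it (the shapes and bounds here are arbitrary placeholders, not taken from any particular environment):

    import numpy as np
    import gym

    observation_space = gym.spaces.Dict(
        {
            "observation": gym.spaces.Box(-np.inf, np.inf, shape=(10,), dtype=np.float64),
            "achieved_goal": gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float64),
            "desired_goal": gym.spaces.Box(-np.inf, np.inf, shape=(3,), dtype=np.float64),
        }
    )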
4 changes: 1 addition & 3 deletions gym/envs/__init__.py
@@ -422,9 +422,7 @@ def _merge(a, b):
     register(
         id="HandManipulateBlockRotateParallel{}-v0".format(suffix),
         entry_point="gym.envs.robotics:HandBlockEnv",
-        kwargs=_merge(
-            {"target_position": "ignore", "target_rotation": "parallel"}, kwargs
-        ),
+        kwargs=_merge({"target_position": "ignore", "target_rotation": "parallel"}, kwargs),
         max_episode_steps=100,
     )

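The block above registers ids of the form HandManipulateBlockRotateParallel{suffix}-v0; with an empty suffix that is HandManipulateBlockRotateParallel-v0. A usage sketch, assuming the robotics/MuJoCo extras (mujoco-py) are installed:

    import gym

    env = gym.make("HandManipulateBlockRotateParallel-v0")  # id built by the register() call above
    obs = env.reset()  # robotics goal envs return a Dict observation (observation/achieved_goal/desired_goal)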
22 changes: 5 additions & 17 deletions gym/envs/algorithmic/algorithmic_env.py
@@ -73,9 +73,7 @@ def __init__(self, base=10, chars=False, starting_min_length=2):
         # 1. Move read head left or right (or up/down)
         # 2. Write or not
         # 3. Which character to write. (Ignored if should_write=0)
-        self.action_space = Tuple(
-            [Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)]
-        )
+        self.action_space = Tuple([Discrete(len(self.MOVEMENTS)), Discrete(2), Discrete(self.base)])
         # Can see just what is on the input tape (one of n characters, or
         # nothing)
         self.observation_space = Discrete(self.base + 1)
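Per the comments above, an action in this space is a 3-tuple: a movement for the read head, a write/no-write flag, and the character to write. A small usage sketch (Copy-v0 is one of the tape-based algorithmic environments built on this class):

    import gym

    env = gym.make("Copy-v0")
    movement, should_write, write_char = env.action_space.sample()
    # movement indexes the environment's MOVEMENTS, should_write is 0 or 1,
    # and write_char is ignored by the environment when should_write == 0.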
@@ -147,10 +145,7 @@ def render(self, mode="human"):
             move = self.MOVEMENTS[inp_act]
             outfile.write("Action : Tuple(move over input: %s,\n" % move)
             out_act = out_act == 1
-            outfile.write(
-                " write to the output tape: %s,\n"
-                % out_act
-            )
+            outfile.write(" write to the output tape: %s,\n" % out_act)
             outfile.write(" prediction: %s)\n" % pred_str)
         else:
             outfile.write("\n" * 5)
@@ -276,9 +271,7 @@ def render_observation(self):
         x_str = "Observation Tape : "
         for i in range(-2, self.input_width + 2):
             if i == x:
-                x_str += colorize(
-                    self._get_str_obs(np.array([i])), "green", highlight=True
-                )
+                x_str += colorize(self._get_str_obs(np.array([i])), "green", highlight=True)
             else:
                 x_str += self._get_str_obs(np.array([i]))
         x_str += "\n"
@@ -311,10 +304,7 @@ def _move(self, movement):
         self.read_head_position = x, y

     def generate_input_data(self, size):
-        return [
-            [self.np_random.randint(self.base) for _ in range(self.rows)]
-            for __ in range(size)
-        ]
+        return [[self.np_random.randint(self.base) for _ in range(self.rows)] for __ in range(size)]

     def _get_obs(self, pos=None):
         if pos is None:
@@ -336,9 +326,7 @@ def render_observation(self):
                 x_str += " " * len(label)
             for i in range(-2, self.input_width + 2):
                 if i == x[0] and j == x[1]:
-                    x_str += colorize(
-                        self._get_str_obs((i, j)), "green", highlight=True
-                    )
+                    x_str += colorize(self._get_str_obs((i, j)), "green", highlight=True)
                 else:
                     x_str += self._get_str_obs((i, j))
             x_str += "\n"
17 changes: 4 additions & 13 deletions gym/envs/algorithmic/tests/test_algorithmic.py
@@ -10,12 +10,8 @@
     alg.reverse.ReverseEnv,
     alg.reversed_addition.ReversedAdditionEnv,
 ]
-ALL_TAPE_ENVS = [
-    env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)
-]
-ALL_GRID_ENVS = [
-    env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)
-]
+ALL_TAPE_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.TapeAlgorithmicEnv)]
+ALL_GRID_ENVS = [env for env in ALL_ENVS if issubclass(env, alg.algorithmic_env.GridAlgorithmicEnv)]


 def imprint(env, input_arr):
@@ -92,10 +88,7 @@ def test_walk_off_the_end(self):

     def test_grid_naviation(self):
         env = alg.reversed_addition.ReversedAdditionEnv(rows=2, base=6)
-        N, S, E, W = [
-            env._movement_idx(named_dir)
-            for named_dir in ["up", "down", "right", "left"]
-        ]
+        N, S, E, W = [env._movement_idx(named_dir) for named_dir in ["up", "down", "right", "left"]]
         # Corresponds to a grid that looks like...
         # 0 1 2
         # 3 4 5
@@ -204,9 +197,7 @@ def test_duplicated_input_target(self):

     def test_repeat_copy_target(self):
         env = alg.repeat_copy.RepeatCopyEnv()
-        self.assertEqual(
-            env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2]
-        )
+        self.assertEqual(env.target_from_input_data([0, 1, 2]), [0, 1, 2, 2, 1, 0, 0, 1, 2])


 class TestInputGeneration(unittest.TestCase):
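The test_repeat_copy_target assertion above pins down the expected target layout: the input sequence, then its reverse, then the input again. Restated as a standalone helper for illustration (a hypothetical re-implementation, not the RepeatCopyEnv source):

    def repeat_copy_target(input_data):
        # input + reversed input + input, matching the expected value in the test above
        return input_data + input_data[::-1] + input_data

    assert repeat_copy_target([0, 1, 2]) == [0, 1, 2, 2, 1, 0, 0, 1, 2]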
35 changes: 11 additions & 24 deletions gym/envs/atari/atari_env.py
@@ -9,8 +9,7 @@
     import atari_py
 except ImportError as e:
     raise error.DependencyNotInstalled(
-        "{}. (HINT: you can install Atari dependencies by running "
-        "'pip install gym[atari]'.)".format(e)
+        "{}. (HINT: you can install Atari dependencies by running " "'pip install gym[atari]'.)".format(e)
     )


@@ -64,35 +63,23 @@ def __init__(

         # Tune (or disable) ALE's action repeat:
         # https://github.com/openai/gym/issues/349
-        assert isinstance(
-            repeat_action_probability, (float, int)
-        ), "Invalid repeat_action_probability: {!r}".format(repeat_action_probability)
-        self.ale.setFloat(
-            "repeat_action_probability".encode("utf-8"), repeat_action_probability
+        assert isinstance(repeat_action_probability, (float, int)), "Invalid repeat_action_probability: {!r}".format(
+            repeat_action_probability
         )
+        self.ale.setFloat("repeat_action_probability".encode("utf-8"), repeat_action_probability)

         self.seed()

-        self._action_set = (
-            self.ale.getLegalActionSet()
-            if full_action_space
-            else self.ale.getMinimalActionSet()
-        )
+        self._action_set = self.ale.getLegalActionSet() if full_action_space else self.ale.getMinimalActionSet()
         self.action_space = spaces.Discrete(len(self._action_set))

         (screen_width, screen_height) = self.ale.getScreenDims()
         if self._obs_type == "ram":
-            self.observation_space = spaces.Box(
-                low=0, high=255, dtype=np.uint8, shape=(128,)
-            )
+            self.observation_space = spaces.Box(low=0, high=255, dtype=np.uint8, shape=(128,))
         elif self._obs_type == "image":
-            self.observation_space = spaces.Box(
-                low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8
-            )
+            self.observation_space = spaces.Box(low=0, high=255, shape=(screen_height, screen_width, 3), dtype=np.uint8)
         else:
-            raise error.Error(
-                "Unrecognized observation type: {}".format(self._obs_type)
-            )
+            raise error.Error("Unrecognized observation type: {}".format(self._obs_type))

     def seed(self, seed=None):
         self.np_random, seed1 = seeding.np_random(seed)
@@ -107,9 +94,9 @@ def seed(self, seed=None):
         if self.game_mode is not None:
             modes = self.ale.getAvailableModes()

-            assert self.game_mode in modes, (
-                'Invalid game mode "{}" for game {}.\nAvailable modes are: {}'
-            ).format(self.game_mode, self.game, modes)
+            assert self.game_mode in modes, ('Invalid game mode "{}" for game {}.\nAvailable modes are: {}').format(
+                self.game_mode, self.game, modes
+            )
             self.ale.setMode(self.game_mode)

         if self.game_difficulty is not None:
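The constructor above chooses between a 128-byte RAM observation and an RGB screen observation whose shape comes from ale.getScreenDims(). A quick way to see both through the standard registrations (requires gym[atari]; the image shape varies by game, 210x160 being Pong's screen size):

    import gym

    ram_env = gym.make("Pong-ram-v0")  # obs_type="ram":   Box(0, 255, (128,), uint8)
    img_env = gym.make("Pong-v0")      # obs_type="image": Box(0, 255, (210, 160, 3), uint8)
    print(ram_env.observation_space)
    print(img_env.observation_space)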
