Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
iffiX committed Jun 18, 2021
1 parent 2da92eb commit a308d55
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 45 deletions.
2 changes: 2 additions & 0 deletions test/data/all.py
Expand Up @@ -19,6 +19,7 @@ def first(iterable, condition=lambda x: True):

def generate_all():
print("Generating all needed data...")
os.makedirs(os.path.join(ROOT, "generated"), exist_ok=True)
for gen in dir(generators):
method = getattr(getattr(generators, gen), "generate", None)
name = getattr(getattr(generators, gen), "generated_name", None)
Expand All @@ -42,6 +43,7 @@ def generate_all():

def get_all():
archives = {}
os.makedirs(os.path.join(ROOT, "generated"), exist_ok=True)
for gen in dir(generators):
method = getattr(getattr(generators, gen), "generate", None)
name = getattr(getattr(generators, gen), "generated_name", None)
Expand Down
35 changes: 15 additions & 20 deletions test/env/wrappers/test_openai_gym.py
Expand Up @@ -11,13 +11,12 @@
linux_only_forall()

from machin.env.wrappers import openai_gym
from machin.utils.logging import default_logger
from random import choice, sample
from colorlog import getLogger
import pytest
import gym
import numpy as np

logger = getLogger("default")
ENV_NUM = 2
SAMPLE_NUM = 2
WORKER_NUM = 2
Expand Down Expand Up @@ -51,35 +50,29 @@ def should_skip(spec):
# Skip mujoco tests
if ep.startswith("gym.envs.mujoco") or ep.startswith("gym.envs.robotics:"):
return True
try:
import atari_py
except ImportError:
if ep.startswith("gym.envs.atari"):
return True

# Skip atari tests
if ep.startswith("gym.envs.atari"):
return True

# Skip other tests
if "GoEnv" in ep or "HexEnv" in ep:
return True

# Conditionally skip box2d tests
try:
import Box2D
except ImportError:
if ep.startswith("gym.envs.box2d"):
return True

if (
"GoEnv" in ep
or "HexEnv" in ep
or (
ep.startswith("gym.envs.atari")
and not spec.id.startswith("Pong")
and not spec.id.startswith("Seaquest")
)
):
return True
return False


@pytest.fixture(scope="module", autouse=True)
def envs():
all_envs = []
env_map = {}
lg = getLogger(__file__)
# Find the newest version of non-skippable environments.
for env_raw_name, env_spec in gym.envs.registry.env_specs.items():
if not should_skip(env_spec):
Expand All @@ -90,9 +83,11 @@ def envs():
# Create environments.
for env_name, env_version in env_map.items():
env_name = env_name + "-v" + str(env_version)
lg.info(f"OpenAI gym {env_name} added")
default_logger.info(f"OpenAI gym {env_name} added")
all_envs.append([lambda *_: gym.make(env_name) for _ in range(ENV_NUM)])
lg.info("{} OpenAI gym environments to be tested.".format(len(all_envs)))
default_logger.info(
"{} OpenAI gym environments to be tested.".format(len(all_envs))
)
return all_envs


Expand Down
8 changes: 5 additions & 3 deletions test/frame/buffers/test_buffer.py
Expand Up @@ -137,7 +137,7 @@ def const_buffer(self, pytestconfig):
"state": {"state_1": t.zeros([1, 2])},
"action": {"action_1": t.zeros([1, 3])},
"next_state": {"next_state_1": t.zeros([1, 4])},
"reward": 10,
"reward": 10.0,
"terminal": True,
"data_index": i,
"not_concatenable": (i, "some_str"),
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_sample(
)
elif attr == "reward":
if concat:
assert self.t_eq(data, t.full([bsize, 1], 10))
assert self.t_eq(data, t.full([bsize, 1], 10.0))
else:
assert (
isinstance(data, list)
Expand All @@ -296,7 +296,9 @@ def test_sample(
)
elif attr == "terminal":
if concat:
assert self.t_eq(data, t.full([bsize, 1], True))
assert self.t_eq(
data, t.full([bsize, 1], True, dtype=t.bool)
)
else:
assert (
isinstance(data, list)
Expand Down
44 changes: 22 additions & 22 deletions test/parallel/distributed/test_world.py
Expand Up @@ -51,7 +51,7 @@ def test_cc_send_recv(rank):
else:
a = t.ones([5])
assert group.recv(a) == 0
assert t.all(a == 0)
assert t.all(a == 0.0)
group.destroy()
return True

Expand All @@ -68,7 +68,7 @@ def test_cc_isend_irecv(rank):
else:
a = t.ones([5])
assert group.irecv(a).wait() == 0
assert t.all(a == 0)
assert t.all(a == 0.0)
group.destroy()
return True

Expand All @@ -84,7 +84,7 @@ def test_cc_broadcast(rank):
else:
a = t.zeros([5])
group.broadcast(a, 0)
assert t.all(a == 1)
assert t.all(a == 1.0)
group.destroy()
return True

Expand All @@ -94,9 +94,9 @@ def test_cc_broadcast(rank):
def test_cc_all_reduce(rank):
world = get_world()
group = world.create_collective_group(ranks=[0, 1, 2])
a = t.full([5], rank)
a = t.full([5], float(rank))
group.all_reduce(a)
assert t.all(a == 3)
assert t.all(a == 3.0)
group.destroy()
return True

Expand All @@ -106,10 +106,10 @@ def test_cc_all_reduce(rank):
def test_cc_reduce(rank):
world = get_world()
group = world.create_collective_group(ranks=[0, 1, 2])
a = t.full([5], 5 - rank)
a = t.full([5], float(5 - rank))
group.reduce(a, 1)
if rank == 1:
assert t.all(a == 12)
assert t.all(a == 12.0)
group.destroy()
return True

Expand All @@ -119,12 +119,12 @@ def test_cc_reduce(rank):
def test_cc_all_gather(rank):
world = get_world()
group = world.create_collective_group(ranks=[0, 1, 2])
a = t.full([5], rank)
a_list = [t.full([5], -1), t.full([5], -1), t.full([5], -1)]
a = t.full([5], float(rank))
a_list = [t.full([5], -1.0), t.full([5], -1.0), t.full([5], -1.0)]
group.all_gather(a_list, a)
assert t.all(a_list[0] == 0)
assert t.all(a_list[1] == 1)
assert t.all(a_list[2] == 2)
assert t.all(a_list[0] == 0.0)
assert t.all(a_list[1] == 1.0)
assert t.all(a_list[2] == 2.0)
group.destroy()
return True

Expand All @@ -134,16 +134,16 @@ def test_cc_all_gather(rank):
def test_cc_gather(rank):
world = get_world()
group = world.create_collective_group(ranks=[0, 1, 2])
a = t.full([5], rank)
a = t.full([5], float(rank))
if rank == 1:
a_list = [t.full([5], -1), t.full([5], -1), t.full([5], -1)]
a_list = [t.full([5], -1.0), t.full([5], -1.0), t.full([5], -1.0)]
else:
a_list = None
group.gather(a, a_list, 1)
if rank == 1:
assert t.all(a_list[0] == 0)
assert t.all(a_list[1] == 1)
assert t.all(a_list[2] == 2)
assert t.all(a_list[0] == 0.0)
assert t.all(a_list[1] == 1.0)
assert t.all(a_list[2] == 2.0)
group.destroy()
return True

Expand All @@ -154,12 +154,12 @@ def test_cc_scatter(rank):
world = get_world()
group = world.create_collective_group(ranks=[0, 1, 2])
if rank == 0:
a_list = [t.full([5], 0), t.full([5], 1), t.full([5], 2)]
a_list = [t.full([5], 0.0), t.full([5], 1.0), t.full([5], 2.0)]
else:
a_list = None
a = t.full([5], -1)
a = t.full([5], -1.0)
group.scatter(a, a_list, 0)
assert t.all(a == 0) or t.all(a == 1) or t.all(a == 2)
assert t.all(a == 0.0) or t.all(a == 1.0) or t.all(a == 2.0)
group.destroy()
return True

Expand All @@ -186,7 +186,7 @@ def test_cc_broadcast_multigpu(rank, gpu):
else:
a = [t.zeros([5], device=gpu)]
group.broadcast_multigpu(a, 0)
assert t.all(a[0] == 1)
assert t.all(a[0] == 1.0)
group.destroy()
return True

Expand All @@ -199,7 +199,7 @@ def test_cc_all_reduce_multigpu(_, gpu):
group = world.create_collective_group(ranks=[0, 1, 2])
a = [t.ones([5], device=gpu)]
group.all_reduce_multigpu(a)
assert t.all(a[0] == 3)
assert t.all(a[0] == 3.0)
group.destroy()
return True

Expand Down

0 comments on commit a308d55

Please sign in to comment.