Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Add new ARENA_MEDIUM task
  • Loading branch information
cswinter committed Nov 17, 2019
1 parent 21b13f4 commit 4f4bd04
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 12 deletions.
41 changes: 30 additions & 11 deletions gym_codecraft/envs/codecraft_vec_env.py
Expand Up @@ -50,15 +50,6 @@ def random_drone():
return drone


def map_arena_tiny_random():
return {
'mapWidth': 1000,
'mapHeight': 1000,
'player1Drones': [random_drone()],
'player2Drones': [random_drone()],
}


def map_arena_tiny(randomize: bool):
storage_modules = 1
constructors = 1
Expand Down Expand Up @@ -115,6 +106,29 @@ def map_arena_tiny_2v2(randomize: bool):
}


def map_arena_medium(randomize: bool):
s1 = 1
if randomize:
s1 = np.random.randint(0, 2)
return {
'mapWidth': 1500,
'mapHeight': 1500,
'player1Drones': [
drone_dict(np.random.randint(-700, 700),
np.random.randint(-700, 700),
constructors=2,
storage_modules=2),
],
'player2Drones': [
drone_dict(np.random.randint(-700, 700),
np.random.randint(-700, 700),
constructors=2 * s1,
storage_modules=2 * s1,
missile_batteries=1 - s1),
],
}


class CodeCraftVecEnv(object):
def __init__(self,
num_envs,
Expand Down Expand Up @@ -145,6 +159,9 @@ def __init__(self,
elif objective == Objective.ARENA_TINY_2V2:
self.game_length = 1 * 30 * 60
self.custom_map = map_arena_tiny_2v2
elif objective == Objective.ARENA_MEDIUM:
self.game_length = 3 * 60 * 60
self.custom_map = map_arena_medium

self.games = []
self.eplen = []
Expand Down Expand Up @@ -247,7 +264,7 @@ def observe(self, env_subset=None, obs_config=None):
game = env_subset[i] if env_subset else i
x = obs[stride * i + GLOBAL_FEATURES + 0]
y = obs[stride * i + GLOBAL_FEATURES + 1]
if self.objective == Objective.ARENA_TINY or self.objective == Objective.ARENA_TINY_2V2:
if self.objective.vs():
allied_score = obs[stride * num_envs + i * NONOBS_FEATURES + 1]
enemy_score = obs[stride * num_envs + i * NONOBS_FEATURES + 2]
score = 2 * allied_score / (allied_score + enemy_score + 1e-8) - 1
Expand Down Expand Up @@ -374,6 +391,7 @@ class Objective(Enum):
DISTANCE_TO_1000_500 = 'DISTANCE_TO_1000_500'
ARENA_TINY = 'ARENA_TINY'
ARENA_TINY_2V2 = 'ARENA_TINY_2V2'
ARENA_MEDIUM = 'ARENA_MEDIUM_alpha'

def vs(self):
if self == Objective.ALLIED_WEALTH or\
Expand All @@ -382,7 +400,8 @@ def vs(self):
self == Objective.DISTANCE_TO_1000_500:
return False
elif self == Objective.ARENA_TINY or\
self == Objective.ARENA_TINY_2V2:
self == Objective.ARENA_TINY_2V2 or\
self == Objective.ARENA_MEDIUM:
return True
else:
raise Exception(f'Objective.vs not implemented for {self}')
Expand Down
6 changes: 5 additions & 1 deletion main.py
Expand Up @@ -342,6 +342,10 @@ def eval(policy, num_envs, device, objective, eval_steps, curr_step=None, oppone
'easy': {'model_file': 'v3/helpful-glade-10M.pt'},
'medium': {'model_file': 'v3/bright-elevator-43M.pt'},
}
elif objective == envs.Objective.ARENA_MEDIUM:
opponents = {
'random': {'model_file': 'v3/random-v3.pt'},
}
else:
raise Exception(f'No eval opponents configured for {objective}')

Expand Down Expand Up @@ -423,7 +427,7 @@ def eval(policy, num_envs, device, objective, eval_steps, curr_step=None, oppone

scores = np.array(scores)

if curr_step:
if curr_step is not None:
wandb.log({
'eval_mean_score': scores.mean(),
'eval_max_score': scores.max(),
Expand Down

0 comments on commit 4f4bd04

Please sign in to comment.