Skip to content

Commit

Permalink
Browse files (browse the repository at this point in the history)
Stagger game start times
  • Loading branch information
cswinter committed Apr 27, 2019
1 parent b6732d2 commit 1434f48
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 15 deletions.
7 changes: 5 additions & 2 deletions codecraft.py
Expand Up @@ -9,9 +9,12 @@

RETRIES = 100

def create_game() -> int:
def create_game(game_length = None) -> int:
try:
response = requests.post('http://localhost:9000/start-game').json()
if game_length:
response = requests.post(f'http://localhost:9000/start-game?maxTicks={game_length}').json()
else:
response = requests.post('http://localhost:9000/start-game').json()
return int(response['id'])
except requests.exceptions.ConnectionError:
logging.info(f"Connection error on create_game, retrying")
Expand Down
10 changes: 6 additions & 4 deletions gym_codecraft/envs/codecraft_vec_env.py
Expand Up @@ -9,7 +9,7 @@


class CodeCraftVecEnv(VecEnv):
def __init__(self, num_envs):
def __init__(self, num_envs, game_length):
observations_low = []
observations_high = []
# Drone x, y
Expand All @@ -26,25 +26,27 @@ def __init__(self, num_envs):
# size
observations_low.append(0)
observations_high.append(2)

super().__init__(
num_envs,
spaces.Box(
low=np.array(observations_low),
high=np.array(observations_high),
dtype=np.float32),
spaces.Discrete(6))

self.games = []
self.eplen = []
self.eprew = []
self.score = []
self.game_length = game_length

def reset(self):
self.games = []
self.eplen = []
self.score = []
for _ in range(self.num_envs):
game_id = codecraft.create_game()
for i in range(self.num_envs):
# spread out initial game lengths (game i gets maxTicks = game_length * (i + 1) // num_envs) to stagger start times
game_id = codecraft.create_game((self.game_length * (i + 1) // self.num_envs))
# print("Starting game:", game_id)
self.games.append(game_id)
self.eplen.append(1)
Expand Down
19 changes: 10 additions & 9 deletions main.py
Expand Up @@ -48,15 +48,16 @@ def run_codecraft():


def train(hps):
env = envs.CodeCraftVecEnv(64)
ppo2.learn(
network=lambda it: network(hps, it),
env=env,
gamma=0.9,
nsteps=hps["rosteps"],
total_timesteps=hps["steps"],
log_interval=1,
lr=hps["lr"])
num_envs = 64 * 128 // hps["rosteps"]
env = envs.CodeCraftVecEnv(num_envs, 3 * 60 * 60)
ppo2.learn(
network=lambda it: network(hps, it),
env=env,
gamma=0.9,
nsteps=hps["rosteps"],
total_timesteps=hps["steps"],
log_interval=1,
lr=hps["lr"])

def network(hps, input_tensor):
#with tf.variable_scope(scope, reuse=reuse):
Expand Down

0 comments on commit 1434f48

Please sign in to comment.