Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Switch to new task
  • Loading branch information
cswinter committed Aug 28, 2019
1 parent bb96bee commit 9c682f7
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
5 changes: 2 additions & 3 deletions gym_codecraft/envs/codecraft_vec_env.py
Expand Up @@ -93,16 +93,15 @@ def observe(self):
dones = []
infos = []
obs = codecraft.observe_batch_raw(self.games)
global_features = 1
global_features = 2
dstride = 7
mstride = 4
stride = global_features + dstride + 10 * mstride
for i in range(self.num_envs):
x = obs[stride * i + global_features + 0]
y = obs[stride * i + global_features + 1]
if self.objective == Objective.ALLIED_WEALTH:
# score = float(observation['alliedScore']) * 0.1
raise Exception("Not implemented")
score = obs[stride * i + 1] * 0.1
elif self.objective == Objective.DISTANCE_TO_ORIGIN:
score = -dist(x, y, 0.0, 0.0)
elif self.objective == Objective.DISTANCE_TO_1000_500:
Expand Down
2 changes: 1 addition & 1 deletion hyper_params.py
Expand Up @@ -31,7 +31,7 @@ def __init__(self):
self.cliprange = 0.2 # PPO cliprange

# Task
self.objective = envs.Objective.DISTANCE_TO_CRYSTAL
self.objective = envs.Objective.ALLIED_WEALTH
self.game_length = 3 * 60 * 60
self.action_delay = 0

Expand Down
12 changes: 6 additions & 6 deletions policy.py
Expand Up @@ -10,12 +10,12 @@ def __init__(self, fc_layers, nhidden, conv):
super(Policy, self).__init__()
self.conv = conv
if conv:
self.fc_drone = nn.Linear(8, nhidden // 2)
self.fc_drone = nn.Linear(9, nhidden // 2)
self.conv_minerals1 = nn.Conv2d(in_channels=1, out_channels=nhidden // 2, kernel_size=(1, 4))
self.conv_minerals2 = nn.Conv2d(in_channels=nhidden // 2, out_channels=nhidden // 2, kernel_size=1)
self.fc_layers = nn.ModuleList([nn.Linear(nhidden, nhidden) for _ in range(fc_layers - 1)])
else:
self.fc_layers = nn.ModuleList([nn.Linear(48, nhidden)])
self.fc_layers = nn.ModuleList([nn.Linear(49, nhidden)])
for _ in range(fc_layers - 1):
self.fc_layers.append(nn.Linear(nhidden, nhidden))

Expand Down Expand Up @@ -67,12 +67,12 @@ def logits(self, x):
def latents(self, x):
if self.conv:
batch_size = x.size()[0]
# x[0:8] is properties of drone 0 and global features
xd = x[:, :8]
# x[0:9] is properties of drone 0 and global features
xd = x[:, :9]
xd = F.relu(self.fc_drone(xd))

# x[8:48] are 10 x 4 properties concerning the closest minerals
xm = x[:, 8:48].view(batch_size, 1, -1, 4)
# x[9:49] are 10 x 4 properties concerning the closest minerals
xm = x[:, 9:].view(batch_size, 1, -1, 4)
xm = F.relu(self.conv_minerals1(xm))
xm = F.max_pool2d(F.relu(self.conv_minerals2(xm)), kernel_size=(10, 1))
xm = xm.view(batch_size, -1)
Expand Down

0 comments on commit 9c682f7

Please sign in to comment.