Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Fix value function?
  • Loading branch information
cswinter committed Dec 20, 2019
1 parent d829d27 commit fa0d399
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions policy_t.py
Expand Up @@ -133,7 +133,7 @@ def backprop(self,
x, x_privileged = self.latents(obs, privileged_obs)
x = x.view(batch_size, (self.allies + self.minerals) * self.d_model)
print(x.size())
values = self.value_head(x)
values = self.value_head(x).view(-1)
# TODO
#if self.use_privileged:
# vin = torch.cat([pooled.view(batch_size, -1), x_privileged.view(batch_size, -1)], dim=1)
Expand Down Expand Up @@ -188,7 +188,7 @@ def forward(self, x, x_privileged):
# vin = torch.cat([pooled.view(batch_size, -1), x_privileged.view(batch_size, -1)], dim=1)
#else:
# vin = pooled.view(batch_size, -1)
values = self.value_head(x)
values = self.value_head(x).view(-1)

logits = self.policy_head(x)
probs = F.softmax(logits, dim=1)
Expand Down

0 comments on commit fa0d399

Please sign in to comment.