Behavioral cloning (#175)
Adds a new `supervised.py` script to enn-ppo that trains a model via behavioral cloning on samples recorded by another policy (a rough sketch of the training objective follows this list). Also makes various improvements to the sample recorder:
- add `--eval-capture-samples`/`--eval-capture-logits` options to record samples/logits to a file during eval
- add `--eval-on-step-0` arg to enable/disable running eval on the first step
- add `--codecraft-only-opponent` to run an eval with only a loaded eval policy playing against itself (this is slightly hacky; I'm planning to remove all the CodeCraft-specific options later)
- include action and observation spaces when recording samples
- fix `RaggedBufferBool` getting deserialized to `None`
- misc fixes to the `SampleRecorder` and `Trace`

Resolves entity-neural-network/incubator#5, entity-neural-network/incubator#6, and entity-neural-network/incubator#8.
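
The `supervised.py` script itself is not part of this diff. As a rough sketch of the behavioral-cloning objective it implements -- all names and signatures below are illustrative, not the actual enn-ppo API -- a training step maximizes the log-probability the model assigns to the recorded actions:

# Illustrative behavioral-cloning step (not the actual enn-ppo code).
import torch
import torch.nn.functional as F

def bc_step(
    policy: torch.nn.Module,           # maps observations to action logits
    obs: torch.Tensor,                 # batch of recorded observations
    actions: torch.Tensor,             # recorded actions, shape (batch,)
    optimizer: torch.optim.Optimizer,
) -> float:
    logits = policy(obs)                      # (batch, n_actions)
    loss = F.cross_entropy(logits, actions)   # = -mean log p(action | obs)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()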
cswinter committed Feb 22, 2022
1 parent 5460441 commit 4cb89e8
Showing 3 changed files with 33 additions and 21 deletions.
rogue_net/actor.py (10 changes: 8 additions & 2 deletions)
@@ -31,9 +31,15 @@
 def tensor_dict_to_ragged(
     rb_cls: Type[RaggedBuffer[ScalarType]],
     d: Dict[str, torch.Tensor],
-    lengths: Dict[str, npt.NDArray[np.int64]],
+    lengths: Dict[str, np.ndarray],
 ) -> Dict[str, RaggedBuffer[ScalarType]]:
-    return {k: rb_cls.from_flattened(v.cpu().numpy(), lengths[k]) for k, v in d.items()}
+    result = {}
+    for k, v in d.items():
+        flattened = v.cpu().numpy()
+        if flattened.ndim == 1:
+            flattened = flattened.reshape(-1, 1)
+        result[k] = rb_cls.from_flattened(flattened, lengths[k])
+    return result
 
 
 class Actor(nn.Module):
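
The reshape above guards against the 1-D action tensors the squeezed action heads (below) now produce: as the `reshape(-1, 1)` in the diff implies, `from_flattened` expects a 2-D (items, features) array. A toy illustration with made-up values, assuming the `ragged_buffer` package's `RaggedBufferI64`:

# Toy illustration of the shape fix (values are made up).
import numpy as np
from ragged_buffer import RaggedBufferI64

flat = np.array([3, 1, 4], dtype=np.int64)   # 1-D, as the squeezed heads emit
lengths = np.array([2, 1], dtype=np.int64)   # two sequences of 2 and 1 items
# from_flattened wants (items, features), so add the feature dimension:
rb = RaggedBufferI64.from_flattened(flat.reshape(-1, 1), lengths)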
rogue_net/head_creator.py (38 changes: 22 additions & 16 deletions)
@@ -60,15 +60,17 @@ def forward(
         lengths = mask.actors.size1()
         if len(mask.actors) == 0:
             return (
-                torch.zeros((0, 1), dtype=torch.int64, device=device),
+                torch.zeros((0), dtype=torch.int64, device=device),
                 lengths,
-                torch.zeros((0, 1), dtype=torch.float32, device=device),
-                torch.zeros((0, 1), dtype=torch.float32, device=device),
+                torch.zeros((0), dtype=torch.float32, device=device),
+                torch.zeros((0), dtype=torch.float32, device=device),
                 torch.zeros((0, self.n_choice), dtype=torch.float32, device=device),
             )
 
-        actors = torch.tensor((mask.actors + index_offsets).as_array()).to(
-            x.data.device
+        actors = (
+            torch.tensor((mask.actors + index_offsets).as_array())
+            .to(x.data.device)
+            .squeeze(-1)
         )
         actor_embeds = x.data[actors]
         logits = self.proj(actor_embeds)
@@ -84,10 +86,10 @@ def forward(
         if prev_actions is None:
             action = dist.sample()
         else:
-            action = torch.tensor(prev_actions.as_array()).to(x.data.device)
+            action = torch.tensor(prev_actions.as_array().squeeze(-1)).to(x.data.device)
         logprob = dist.log_prob(action)
         entropy = dist.entropy()
-        return action, lengths, logprob, entropy, logits
+        return action, lengths, logprob, entropy, dist.logits
 
 
 class PaddedSelectEntityActionHead(nn.Module):
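
The net effect of the two hunks above is that this head now works with flat (n_actors,) index and action tensors instead of (n_actors, 1). A toy check, with hypothetical shapes, of why the `.squeeze(-1)` on the index tensor matters:

# Indexing embeddings with an (n, 1) index tensor keeps the singleton
# dimension; squeezing it first yields the intended (n, d_model) result.
import torch

x_data = torch.randn(5, 8)               # 5 entity embeddings, d_model = 8
actors = torch.tensor([[0], [2], [4]])   # (3, 1), as stored in the ragged buffer
assert x_data[actors].shape == (3, 1, 8)           # unwanted extra dim
assert x_data[actors.squeeze(-1)].shape == (3, 8)  # flat, as intended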
@@ -119,14 +121,16 @@ def forward(
         actor_lengths = mask.actors.size1()
         if len(mask.actors) == 0:
             return (
-                torch.zeros((0, 1), dtype=torch.int64, device=device),
+                torch.zeros((0), dtype=torch.int64, device=device),
                 actor_lengths,
-                torch.zeros((0, 1), dtype=torch.float32, device=device),
-                torch.zeros((0, 1), dtype=torch.float32, device=device),
-                torch.zeros((0, 1), dtype=torch.float32, device=device),
+                torch.zeros((0), dtype=torch.float32, device=device),
+                torch.zeros((0), dtype=torch.float32, device=device),
+                torch.zeros((0), dtype=torch.float32, device=device),
             )
 
-        actors = torch.tensor((mask.actors + index_offsets).as_array(), device=device)
+        actors = torch.tensor(
+            (mask.actors + index_offsets).as_array(), device=device
+        ).squeeze(-1)
         actor_embeds = x.data[actors]
         queries = self.query_proj(actor_embeds).squeeze(1)
         max_actors = actor_lengths.max()
@@ -147,7 +151,9 @@ def forward(
         query_mask = query_mask.view(len(actor_lengths), max_actors)
 
         actee_lengths = mask.actees.size1()
-        actees = torch.tensor((mask.actees + index_offsets).as_array(), device=device)
+        actees = torch.tensor(
+            (mask.actees + index_offsets).as_array(), device=device
+        ).squeeze(-1)
         actee_embeds = x.data[actees]
         keys = self.key_proj(actee_embeds).squeeze(1)
         max_actees = actee_lengths.max()
@@ -187,10 +193,10 @@ def forward(
         logprob = dist.log_prob(action)
         entropy = dist.entropy()
         return (
-            action.flatten()[qindices].view(-1, 1),
+            action.flatten()[qindices],
             actor_lengths,
-            logprob.flatten()[qindices].view(-1, 1),
-            entropy.flatten()[qindices].view(-1, 1),
+            logprob.flatten()[qindices],
+            entropy.flatten()[qindices],
             logits,
         )
 
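The same flattening applies at the end of `PaddedSelectEntityActionHead.forward`: sampled values in the padded (groups, max_actors) layout are flattened and gathered at `qindices`, the flat positions of real (non-padding) actors, so the head now returns (n_actors,) tensors rather than (n_actors, 1). A toy example with made-up values:

# Gathering real actors out of a padded layout (values are made up).
import torch

padded_action = torch.tensor([[7, 9],      # group 0: 2 real actors
                              [3, 0]])     # group 1: 1 real actor + padding
qindices = torch.tensor([0, 1, 2])         # flat positions of real actors
action = padded_action.flatten()[qindices] # tensor([7, 9, 3]), shape (3,)
assert action.shape == (3,)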
rogue_net/tests/test_action_head.py (6 changes: 3 additions & 3 deletions)
@@ -29,8 +29,8 @@ def test_empty_actors() -> None:
         ),
         prev_actions=None,
     )
-    assert action.shape == (0, 1)
+    assert action.shape == (0,)
     assert np.array_equal(lengths, np.array([0, 0, 0, 0]))
-    assert logprob.shape == (0, 1)
-    assert entropy.shape == (0, 1)
+    assert logprob.shape == (0,)
+    assert entropy.shape == (0,)
     assert logits.shape == (0, 2)
