
Implement relative positional encoding #139

Merged
merged 6 commits into main from clemens/relative-positional-encoding on Jan 13, 2022

Conversation


@cswinter (Collaborator) commented on Jan 12, 2022

Implements a version of relative positional encoding for n-dimensional grids. For example, relative positional encoding with an 11 x 13 extent for an environment with a 2D grid can be enabled by passing `--relpos-encoding='{"extent": [5, 6], "position_features": ["x", "y"], "per_entity_values": true}'` to `enn_ppo/train.py`.
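As a sanity check on the numbers above, this is presumably how the extent maps to the 11 x 13 grid of relative offsets (offsets clipped to [-5, 5] x [-6, 6]; the snippet is illustrative, not code from the PR):

```python
extent = [5, 6]                                    # from the example config above
offsets_per_dim = [2 * e + 1 for e in extent]      # [11, 13] relative offsets
num_embeddings = offsets_per_dim[0] * offsets_per_dim[1]
print(offsets_per_dim, num_embeddings)             # [11, 13] 143 key/value slots
```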

There are many variations and refinements of relative positional encoding. This implementation mostly follows the original formulation described in Shaw et al. (2018). In particular, here is a non-exhaustive list of somewhat arbitrary design choices that we may want to revisit once we have some good benchmarks to test against (see the sketch after this list):

  • There are relative positional keys and values, but no queries.
  • Keys and values are shared across all layers and heads.
  • We initialize values ~ Normal(0, 0.2) and keys ~ Normal(0, 0.05).
  • By default, values are different per entity type (this turns out to be quite important).
  • Relative positional keys/values have the same dimension as the heads.
  • All dimensions are combined, so every combination of x/y/z/... positions within the extent gets its own key/value embedding. We could also have separate keys/values for each dimension, which would reduce the total number of embeddings and might be preferable in some cases.
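To make these choices concrete, here is a minimal PyTorch sketch of a single attention head with Shaw-style relative keys and per-entity-type relative values. All names, shapes, and the gather logic are illustrative assumptions, not the actual enn_ppo implementation:

```python
import torch
import torch.nn as nn

class RelposAttentionHead(nn.Module):
    def __init__(self, d_head: int, n_positions: int, n_entity_types: int):
        super().__init__()
        # Shared across all layers/heads; init matches the list above.
        self.rel_keys = nn.Parameter(torch.randn(n_positions, d_head) * 0.05)
        self.rel_values = nn.Parameter(
            torch.randn(n_entity_types, n_positions, d_head) * 0.2
        )

    def forward(self, q, k, v, rel_index, entity_type):
        # q, k, v: (batch, seq, d_head) projected queries/keys/values
        # rel_index: (batch, seq, seq) flattened relative grid offset per
        #            query/key pair, derived from the position features and
        #            clipped to the extent
        # entity_type: (batch, seq) type id of each entity
        d = q.shape[-1]
        rk = self.rel_keys[rel_index]                             # (b, s, s, d)
        logits = q @ k.transpose(-1, -2)                          # content term
        logits = logits + torch.einsum("bqd,bqkd->bqk", q, rk)    # relpos term
        attn = torch.softmax(logits / d ** 0.5, dim=-1)
        # Per-entity values: select the value table of the attended-to entity.
        rv = self.rel_values[entity_type[:, None, :], rel_index]  # (b, s, s, d)
        return attn @ v + torch.einsum("bqk,bqkd->bqd", attn, rv)
```

The gathered per-pair value tensor `rv` is also where the d·s² memory cost discussed below comes from.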

The current implementation requires d·s² memory, where d is the dimension of the heads and s is the sequence length. Since our sequences are relatively short so far, this does not present a major issue. The usual trick for reducing memory usage by a factor of s only works for sequences, not for our more general version where entities can be at arbitrary grid points. We could still achieve the same savings with a custom GPU kernel, though.
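A rough illustration of that cost (the head dimension, sequence length, and batch size below are made-up numbers, not values from our experiments):

```python
d, s, batch = 64, 128, 512     # head dim, sequence length, batch size (assumed)
floats = batch * s * s * d     # one gathered relpos value per query/key pair
print(f"{floats * 4 / 2**20:.0f} MiB at fp32")  # 2048 MiB for a single layer
```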

@cswinter linked an issue on Jan 12, 2022 that may be closed by this pull request
@cswinter force-pushed the clemens/relative-positional-encoding branch from bfa5269 to 05f1d1d on January 13, 2022 at 02:19
@cswinter (Collaborator, Author) commented:

Some basic ablations here: https://wandb.ai/entity-neural-network/enn-ppo/reports/Relative-positional-encoding-ablations--VmlldzoxNDM0MzIx
Relative positional encoding somewhat outperforms the tuned baseline that uses translation, and greatly outperforms policies that only see the raw position features.

A big caveat is that, at least on this task, we seem to need per-entity relative positional values to get good performance.
I believe the reason is that per-entity values allow a single attention head/layer to easily access/compute per-entity positional information in a way that is impossible without them.
Imagine a single head that attends from the actor entity to two entities equally: a snake segment entity and a food entity.
The output of the attention head will be `0.5 * (value[snake] + value[food] + relposvalue[snake.pos] + relposvalue[food.pos])`, where `value[x]` is the normal value vector derived from the embedding of entity `x`, and `relposvalue[x.pos]` is the relative positional value embedding for the position of entity `x`.
The actor now has access to the following information:

  • there is a "food" entity
  • there is a "snake segment" entity
  • these two entities are located at positions snake.pos and food.pos

Without per-entity relative positional values, the result of the head does not change when the positions of the food and the snake are switched, which makes it impossible to tell which entity is at which location. With per-entity positional values, however, the head can immediately extract features that tell it where the food and snake entities are located (see the toy illustration below).
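A toy version of this argument, directly encoding the head output formula from above:

```python
# With relpos values shared across entity types, swapping the snake's and the
# food's positions leaves the head output unchanged, so the actor cannot tell
# which entity is where.
import torch

torch.manual_seed(0)
d = 4
value = {"snake": torch.randn(d), "food": torch.randn(d)}
relposvalue = {"pos_a": torch.randn(d), "pos_b": torch.randn(d)}  # shared

def head_output(snake_pos: str, food_pos: str) -> torch.Tensor:
    return 0.5 * (value["snake"] + value["food"]
                  + relposvalue[snake_pos] + relposvalue[food_pos])

print(torch.allclose(head_output("pos_a", "pos_b"),
                     head_output("pos_b", "pos_a")))  # True: ambiguous
```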

While per-entity relative positional values are a good solution in the case of this environment, they are also quite limited. If, instead of separate food and snake entities, there was a single entity type with a feature that identifies whether it is "food" or "snake segment", per-entity relpos values wouldn't apply. This seems wrong; we ought to have a solution that works just as well in that case. More generally, the relevant property might not be the entity type but some arbitrary feature of the entity learned by the network.

I believe we can come up with a new type of relative positional encoding that is fully general by allowing for a non-linear combination of (a projection of) the entity embeddings and the positional features. Since there are N² relative positional values, we probably can't afford a full matmul, but there are some cheap element-wise operations that I think could work well. In particular, a good approach could be to perform an element-wise multiplication of the relative positional values and a projection of the corresponding entity embedding, using one of the GLU variants described in Shazeer (2020). This would effectively allow entities to apply an arbitrary gating function to any of the relative positional values, and should be strictly more powerful than per-entity positional encodings (sketched below).
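A minimal sketch of what that gating could look like; the module name, the sigmoid gate, and the projection are hypothetical design choices, not code from this PR:

```python
import torch
import torch.nn as nn

class GatedRelposValues(nn.Module):
    def __init__(self, d_model: int, n_positions: int, d_head: int):
        super().__init__()
        self.rel_values = nn.Parameter(torch.randn(n_positions, d_head) * 0.2)
        self.gate_proj = nn.Linear(d_model, d_head)

    def forward(self, embeddings: torch.Tensor, rel_index: torch.Tensor):
        # embeddings: (batch, seq, d_model) entity embeddings
        # rel_index:  (batch, seq, seq) relative offset per query/key pair
        rv = self.rel_values[rel_index]                    # (b, q, k, d_head)
        gate = torch.sigmoid(self.gate_proj(embeddings))   # (b, k, d_head)
        # Each attended-to entity applies a learned element-wise gate to the
        # relpos value, instead of selecting a value table by entity type.
        return rv * gate[:, None, :, :]
```

Because the gate is a function of the full entity embedding rather than the type id, this should subsume the per-entity-type case while also covering the "type encoded as a feature" scenario.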

An important related question is whether all of this is even necessary. In principle, a multi-layer or multi-head attention network ought to be able to perform the same operation. E.g., a two-head attention layer could retrieve all snake entities with one of the heads and all food entities with the other head, which allows it to separately access and project the positions of the different entity types. Empirically, I haven't been able to get good performance even with networks with multiple layers/heads. It would be good to understand this better. Some thoughts:

  • Is this due to some inherent limitation that makes it much less efficient to express this operation with multiple heads as opposed to per-entity values, or does it just take additional time to learn to specialize the heads in that way?
    • We could probably just work through the math and pinpoint precisely in what ways per-entity values and multiple heads differ.
    • Empirically, we could try to come up with hyperparameters that allow us to train a multi-head/layer network to the same performance as per-entity relpos values and see whether the difference mostly goes away as you scale compute (so it's just some constant overhead at the start of training), or try to explicitly bias heads to attend only to specific entity types and see whether that allows them to train as quickly as per-entity relpos values.
  • It could also be insightful to look at the attention patterns and the function computed by the policy, to see what algorithm the policy actually uses and whether it matches our intuition for why per-entity values perform better.

@cswinter merged commit eb62f6f into main on Jan 13, 2022
@cswinter deleted the clemens/relative-positional-encoding branch on January 13, 2022 at 06:33
Linked issue closed by this pull request: Relative positional encoding for grids