Description
Thanks for another great implementation, Phil!
You're using the Attention block to do attention between the latent features (i.e. the "process" step from the RIN paper).
It looks like you're not layer-normalizing the context features in Attention when no context is provided (a logical move :)).
When you initialize the latent attention blocks in RINBlock, you specify norm=True, so the latent features are layer-normalized before the query vectors are computed. Unfortunately, the context is set to the unnormalized latent features, which are then used to compute the keys and values (see the sketch below).
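To illustrate what I mean, here is a minimal sketch of the suspected ordering issue. This is not the actual code from this repo; the names (`norm`, `norm_context`, `to_q`, `to_kv`, `default`) and the single-head simplification are assumptions modeled on the usual structure of these attention blocks:

```python
import torch
from torch import nn

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

class Attention(nn.Module):
    # simplified sketch, single head, no projections of the output
    def __init__(self, dim, norm = True, norm_context = False):
        super().__init__()
        self.norm = nn.LayerNorm(dim) if norm else nn.Identity()
        self.norm_context = nn.LayerNorm(dim) if norm_context else nn.Identity()
        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_kv = nn.Linear(dim, dim * 2, bias = False)

    def forward(self, x, context = None):
        # context falls back to the *unnormalized* x here ...
        context = default(context, x)

        # ... so this LayerNorm only affects the query path
        x = self.norm(x)
        context = self.norm_context(context)  # Identity when norm_context = False

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # queries come from normalized features, keys/values from unnormalized ones
        attn = (q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5).softmax(dim = -1)
        return attn @ v
```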
I could be wrong, or this could be intentional. Just wanted to give you a heads-up.
TL;DR: When you use the Attention class to do regular attention (not cross-attention), the features used to compute the keys and values may not be normalized.
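For what it's worth, a minimal fix under the same assumptions as the sketch above would be to resolve the default context only after `x` has been normalized, e.g.:

```python
    def forward(self, x, context = None):
        x = self.norm(x)

        # the self-attention fallback now sees the normalized latents,
        # so queries, keys and values all come from the same features
        context = self.norm_context(default(context, x))

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        attn = (q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5).softmax(dim = -1)
        return attn @ v
```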