nnx MultiHeadAttention currently requires a single "in_features" value at construction time, but I'd like to compute the query from an input with feature dim H1 and the key/value from an input with feature dim H2. I believe the previous linen implementation supports this thanks to shape inference. Could nnx support this use case, either via an additional kwarg or by generalizing the values accepted for "in_features"?