
Dimensions of forward_recurrent #36

Closed
Qiu30 opened this issue Nov 21, 2023 · 5 comments

Comments

@Qiu30

Qiu30 commented Nov 21, 2023

In the MultiScaleRetention class, it is mentioned that `s_n_1s` has dimensions (batch_size, heads, head_size, head_size), while in SimpleRetention, `s_n_1` is defined as `s_n_1s[i]`. However, you mention that `s_n_1` has dimensions (batch_size, hidden_size, v_dim). Can you clarify this?

@DinoMan

DinoMan commented Nov 22, 2023

@Qiu30 I just had a closer look at the code (and tests.py), and you need to note that `s_n_1s` is a list in MultiScaleRetention. What the comment means is that the list as a whole covers (batch_size, heads, head_size, head_size): each of its num_heads elements is a tensor of shape (batch_size, head_size, head_size). As for RetNet, the state is a list of lists, with each innermost element again being (batch_size, head_size, head_size). So to summarize:

Retention --> Sn-1: (batch x head_dim x head_dim)
MultiscaleRetention --> Sn-1s: list with num_heads elements (tensors), each with shape (batch x head_dim x head_dim)
RetNet --> Sn-1s: list with num_layers elements; each element is itself a list of tensors, each with shape (batch x head_dim x head_dim)
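
To make the shapes concrete, here is a minimal sketch of how these states could be allocated (the sizes below are placeholders I picked for illustration, not values from this repo):

```python
import torch

# Illustrative sizes (placeholders, not taken from the repo).
batch_size, heads, head_size, num_layers = 2, 4, 16, 3

# Retention: a single state tensor, (batch x head_dim x head_dim).
s_n_1 = torch.zeros(batch_size, head_size, head_size)

# MultiScaleRetention: a list with num_heads elements, each a tensor
# of shape (batch x head_dim x head_dim).
s_n_1s = [torch.zeros(batch_size, head_size, head_size) for _ in range(heads)]

# RetNet: a list with num_layers elements; each element is itself a
# list of per-head state tensors, (batch x head_dim x head_dim).
states = [
    [torch.zeros(batch_size, head_size, head_size) for _ in range(heads)]
    for _ in range(num_layers)
]
```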

I hope this helps.

@Jamie-Stirling
Owner

Hi all, thanks very much for raising this and identifying the issue. I'll update the comments when I get time.

@Qiu30
Author

Qiu30 commented Nov 22, 2023

@DinoMan @Jamie-Stirling Thank you for your replies. I had the same understanding, but I have a question: what is the initial value of `s_n_1`? I searched the paper and did not find the initial value mentioned anywhere.

@Jamie-Stirling
Owner

Hi, in the code I initialize this to zeros; however, this detail is not mentioned in the paper. I'm not sure of the impact of the choice of initial value on training, but setting it to zeros ensures that only the keys and values computed from the first token affect the state at t=1, akin to a transformer.

I would say that setting a nonzero constant or a trainable value for the initial state is analogous to introducing a bias term, and so it may affect the way the RetNet trains. I'm not an expert though, so it may be best to ask the authors of the original paper to be sure.
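
To illustrate why zero initialization has this effect, here is a minimal sketch of the paper's recurrent update S_n = gamma * S_{n-1} + K_n^T V_n with the state started at zero (the sizes and decay gamma below are placeholders I picked, not values from this repo):

```python
import torch

# Placeholder sizes and decay; not taken from the repo.
batch_size, head_size = 2, 16
gamma = 0.97

# S_0 = 0: the state starts empty.
s = torch.zeros(batch_size, head_size, head_size)

# Key and value for the first token, shape (batch, 1, head_size).
k1 = torch.randn(batch_size, 1, head_size)
v1 = torch.randn(batch_size, 1, head_size)

# S_1 = gamma * S_0 + K_1^T V_1. With S_0 = 0 the decayed term
# vanishes, so the state at t=1 depends only on the first token's
# key and value, akin to a transformer attending over one token.
s = gamma * s + k1.transpose(1, 2) @ v1
```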

@Qiu30
Author

Qiu30 commented Nov 22, 2023

@Jamie-Stirling I understand, thanks for your reply!

@Qiu30 Qiu30 closed this as completed Nov 22, 2023