
Segment-Wise Attention #2

Closed
jlamprou opened this issue Apr 13, 2024 · 11 comments
@jlamprou

From what I understand from the paper, the sequence is segmented into smaller segments and fed to the attention layers. Is this not implemented yet in this implementation?

@Beomi
Owner

Beomi commented Apr 13, 2024

Exactly! I'm currently implementing it part by part, so yes, it's not implemented yet :)

@jlamprou
Author

I've been working on an implementation since the paper's release, and the only part I'm having problems with is the segment-wise part. Using your implementation, I segment the input and feed each segment to self.attn with a for loop, but on the second segment I get a shape mismatch at memory_output = memory_output / norm_term_expanded: at dimension 3, memory_output has size head_dim, but norm_term_expanded has size head_dim * num_heads. So I don't know whether my segmentation logic is wrong or your Infini-attention implementation doesn't account for the accumulation of segments. I would be grateful for any tips.
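To make the mismatch concrete, here's a rough sketch of the shapes I mean (hypothetical names and sizes, not the repo's actual tensors):

```python
import torch

# Illustrative shapes only (hypothetical values, not the repo's actual config).
batch, num_heads, seg_len, head_dim = 2, 8, 128, 64

memory_output = torch.randn(batch, num_heads, seg_len, head_dim)

# Keeping the normalization term per head lets it broadcast over head_dim:
norm_term = torch.randn(batch, num_heads, seg_len, 1)
norm_term_expanded = norm_term.expand(-1, -1, -1, head_dim)   # (b, h, s, head_dim)

out = memory_output / norm_term_expanded                       # shapes line up

# If the accumulated norm term is instead carried with the heads merged,
# its last dimension becomes head_dim * num_heads and the division above
# raises the size mismatch I'm seeing on the second segment.
```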

@zzr-idam

Although we attempted the segmentation approach, inference with the model is very slow. Any suggestions?

@jlamprou
Author

@Beomi I think the segmentation is not supposed to happen inside the attention, but before passing the inputs to the whole transformer block. Passing the whole input to the attention still requires loading the whole sequence into VRAM. If we assume a Hugging Face-like model, I'd say it belongs either in the decoder layer class or in the model class.
(attached screenshot: Screenshot_20240414-160553_Brave)

@Beomi
Owner

Beomi commented Apr 14, 2024

@jlamprou You're right, maybe I have to edit the decoder layer class. Thanks for the note.

@jlamprou
Author

@Beomi More likely the model class: per my understanding, we segment the input and feed each segment to the decoder layers. What I'm not so sure about is how we manage the compressed memory during the backward pass. I don't think we need gradients for the compressed memory, so we should probably not assign self.norm_term and self.memory directly in the memory update, but instead create new variables and then assign them to self.memory and self.norm_term with either detach() or torch.no_grad(). A minimal sketch of what I mean is below.
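This is only an illustration with hypothetical names (InfiniAttentionMemory, update), not the repo's actual code:

```python
import torch

class InfiniAttentionMemory:
    """Toy holder for the compressive memory state (hypothetical names)."""

    def __init__(self):
        self.memory = None      # associative matrix, e.g. (batch, heads, head_dim, head_dim)
        self.norm_term = None   # normalization term, e.g. (batch, heads, head_dim, 1)

    def update(self, new_memory, new_norm_term):
        # Detach so gradients never flow through the stored memory across
        # segment boundaries; only the current segment is backpropagated.
        self.memory = new_memory.detach()
        self.norm_term = new_norm_term.detach()
```

Wrapping the state assignment in torch.no_grad() instead of calling detach() should amount to the same thing for the stored tensors.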

@Beomi
Owner

Beomi commented Apr 14, 2024

@jlamprou I've been reconsidering your point that "passing the entire input to the attention mechanism still requires loading the whole sequence into VRAM." However, I believe that regardless of the method chosen, we end up loading all of the input into VRAM eventually, and this is O(N) (where N is the input length). The key issue, though, is that the paper aims to reduce memory usage in the quadratic component, which is the usual cost of self-attention. Thus, even when we segment inside the attention loop, the overall memory for the input sequence may be larger, but it has to reside somewhere, either on the CPU or the GPU. Therefore, worrying about the linearly growing memory usage isn't as crucial, since the VRAM used by the attention itself is fixed. What do you think?

@Beomi
Owner

Beomi commented Apr 14, 2024

@zzr-idam In the published paper, they mention: "in this work we parameterize the memory with an associative matrix / cast the memory update and retrieval process as linear attention mechanism / we adopt the update rule and retrieval mechanism by Katharopoulos et al. (2020) mainly due to its simplicity and competitive performance", so they likely used Katharopoulos et al. (2020) (https://arxiv.org/pdf/2006.16236.pdf).

This repo hasn't implemented that paper's method yet, and I think that's the reason for the slow inference.
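For reference, the retrieval and update rule from those equations looks roughly like this (a sketch under my reading of the paper, with illustrative shapes, not this repo's code):

```python
import torch
import torch.nn.functional as F

def elu_feature_map(x):
    # sigma(x) = ELU(x) + 1, the kernel used by Katharopoulos et al. (2020)
    return F.elu(x) + 1.0

def memory_retrieve_and_update(q, k, v, memory, norm_term):
    """Linear-attention-style compressive memory (illustrative shapes).

    q, k, v:   (batch, heads, seg_len, head_dim)
    memory:    (batch, heads, head_dim, head_dim)
    norm_term: (batch, heads, head_dim, 1)
    """
    sigma_q = elu_feature_map(q)
    sigma_k = elu_feature_map(k)

    # Retrieval: A_mem = sigma(Q) M / (sigma(Q) z)
    retrieved = sigma_q @ memory                     # (b, h, s, d)
    norm = sigma_q @ norm_term                       # (b, h, s, 1)
    retrieved = retrieved / norm.clamp(min=1e-6)

    # Update (linear rule, no delta rule): M <- M + sigma(K)^T V, z <- z + sum_t sigma(K_t)
    new_memory = memory + sigma_k.transpose(-2, -1) @ v
    new_norm_term = norm_term + sigma_k.sum(dim=-2, keepdim=True).transpose(-2, -1)

    return retrieved, new_memory, new_norm_term
```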

@jlamprou
Author

jlamprou commented Apr 14, 2024

@Beomi I'm testing both ways of implementing the segmentation right now. You're right based on my tests: the VRAM usage difference is small, with segmentation inside the attention consuming only about 1 GB extra but with better throughput, so it's probably best to keep the current implementation. The actually weird thing is that classic SDPA attention (as in the original Hugging Face implementation) consumes the same amount of VRAM too, with no segmenting or anything...

We should probably take a look at Memformers - Pytorch, which implements a recurrent trainer. Maybe the segmentation shouldn't happen in the model at all, but in the training loop? The paper states: "We set the Infini-attention segment length N to 2048 for all attention layers and the input sequence length to 32768 for training. This allows the Infini-attention to unroll over 16 steps w.r.t its compressive memory states.", which could mean training steps.
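If the segmentation moves to the training loop, I imagine something like the following (a rough sketch; reset_memory() and the labels/loss interface are assumptions on my side, not this repo's API):

```python
import torch

def train_step(model, optimizer, input_ids, labels, segment_len=2048):
    """One hypothetical training step that segments inside the loop.

    The compressive memory carries over between segments (the "unrolling"
    the paper mentions); gradients accumulate and we step once per sequence.
    """
    model.reset_memory()            # assumed helper: clear the compressive memory
    optimizer.zero_grad()
    total_loss, num_segments = 0.0, 0
    for start in range(0, input_ids.size(1), segment_len):
        seg_ids = input_ids[:, start:start + segment_len]
        seg_labels = labels[:, start:start + segment_len]
        out = model(input_ids=seg_ids, labels=seg_labels)
        out.loss.backward()         # backprop per segment; stored memory is detached
        total_loss += out.loss.item()
        num_segments += 1
    optimizer.step()
    return total_loss / num_segments
```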

@Beomi
Owner

Beomi commented Apr 17, 2024

@jlamprou Hi, I think it's time to consider exposing both options to end users and letting them select whichever way is beneficial.

As you said before, the memory savings from the attention itself are only a small fraction; other processing such as the MLP layers or even the embeddings also increases VRAM usage, which makes it hard to use a bigger block size.

In my experience (training code), the VRAM required is almost the same as the original implementation, so maybe your implementation direction would be more helpful in terms of VRAM usage.

Or maybe there is room to make it work like an adapter (PEFT-style)? What do you think?

@jlamprou
Author

@Beomi I ran some tests to check the validity of segmenting in the training loop. I tested the accuracy at every batch using the concatenation of the per-segment logits and labels, to check whether accuracy over the total sequence length improves during training, and once the learnable beta got some data we reached the same accuracy rate as normal SDPA attention. Check the implementation in my repo.
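The accuracy check itself is just token-level accuracy over the concatenated segments, roughly like this (a sketch with assumed bookkeeping, not the exact code in my repo):

```python
import torch

def sequence_accuracy(all_logits, all_labels):
    """Token-level accuracy over the full sequence from per-segment outputs.

    all_logits / all_labels are lists collected segment by segment,
    with shapes (batch, seg_len, vocab) and (batch, seg_len).
    """
    logits = torch.cat(all_logits, dim=1)
    labels = torch.cat(all_labels, dim=1)
    preds = logits.argmax(dim=-1)
    mask = labels != -100            # skip masked positions (usual -100 label convention)
    return (preds[mask] == labels[mask]).float().mean().item()
```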
