Thanks for your great work. Here I want to ask why fft(k) is not fused into the kernel, is it a performance issue?
I mean why is it implemented as follows:
    def fftconv_fast(u, k, D, dropout_mask):
        """Fuse padding + rfft + pointwise mult + ifft + multiply with D + gelu + dropout"""
        seqlen = u.shape[-1]
        fft_size = 2 * seqlen
        k_f = torch.fft.rfft(k, n=fft_size)
        out = fftconv_fwd(u, k_f, D, dropout_mask, fft_size)
        return out
instead of:
    def fftconv_fast(u, k, D, dropout_mask):
        """Fuse padding + rfft + pointwise mult + ifft + multiply with D + gelu + dropout"""
        seqlen = u.shape[-1]
        fft_size = 2 * seqlen
        out = fftconv_fwd(u, k, D, dropout_mask, fft_size)
        return out
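For reference, the pipeline that the docstring describes can be sketched in plain PyTorch. This is only a sketch of the unfused semantics, not the actual CUDA kernel: the shapes u: (B, H, L), k: (H, L), D: (H,) follow the maintainer's reply, and the gelu/dropout ordering is an assumption.

```python
import torch
import torch.nn.functional as F

def fftconv_ref(u, k, D, dropout_mask, gelu=True):
    """Unfused reference: pad + rfft + pointwise mult + irfft + D skip + gelu + dropout.

    u: (B, H, L) input, k: (H, L) kernel, D: (H,) skip weights.
    """
    seqlen = u.shape[-1]
    # Zero-pad to 2L so the circular FFT convolution matches a linear (causal) one.
    fft_size = 2 * seqlen
    k_f = torch.fft.rfft(k, n=fft_size)   # (H, fft_size // 2 + 1)
    u_f = torch.fft.rfft(u, n=fft_size)   # (B, H, fft_size // 2 + 1)
    # Pointwise multiply in frequency space, invert, keep the first L outputs.
    y = torch.fft.irfft(u_f * k_f, n=fft_size)[..., :seqlen]
    out = y + u * D.unsqueeze(-1)         # skip connection scaled by D
    if gelu:
        out = F.gelu(out)
    if dropout_mask is not None:
        out = out * dropout_mask
    return out
```

The fused kernel performs these same steps in one pass over u, which is why everything except fft(k) lives inside fftconv_fwd.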
The shape of k is (H, L), so the IO cost of computing its FFT outside the kernel is small. The shape of u is (B, H, L), so it dominates the IO cost. We also parallelize over B and H, so naively fusing fft(k) would redundantly recompute it in every block along the batch dimension.
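To make the IO argument concrete (with hypothetical shapes chosen for illustration, not values from this thread): reading u costs B times more than reading k, so precomputing k_f once on the host side is cheap relative to the kernel's total memory traffic.

```python
# Hypothetical shapes for illustration only.
B, H, L = 64, 768, 2048

u_elems = B * H * L   # input u: (B, H, L)
k_elems = H * L       # kernel k: (H, L)

# u dominates the kernel's IO by a factor of the batch size.
print(u_elems // k_elems)  # -> 64
```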