## `PyTorch` implementation of the `LongNet` dilated attention
Dilated attention is a technique for greatly extending the context length of transformers.
It splits the input into segments and applies a "dilated" (strided) selection within each segment for every multi-head attention block,
creating a hierarchy of heads that attend to progressively more context without incurring the quadratic memory
and compute cost of full self-attention.
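
To make the segment-and-dilate idea concrete, here is a minimal pure-Python sketch of which positions survive the selection and how many query-key pairs that leaves compared with full attention. The helper names `dilated_indices` and `attention_pairs`, and the `offset` parameter, are hypothetical and used only for illustration; they are not part of `src.dilated_attention`, and the actual module may select positions differently.

```python
def dilated_indices(seq_len, segment_size, dilation, offset=0):
    """Return, for each segment, the token positions kept after dilation.

    Hypothetical helper: splits range(seq_len) into equal segments and keeps
    every `dilation`-th position inside each one, starting at `offset`.
    """
    if seq_len % segment_size != 0:
        raise ValueError("seq_len must be a multiple of segment_size")
    if not 0 <= offset < dilation:
        raise ValueError("offset must satisfy 0 <= offset < dilation")
    return [
        list(range(start + offset, start + segment_size, dilation))
        for start in range(0, seq_len, segment_size)
    ]


def attention_pairs(seq_len, segment_size, dilation):
    """Count query-key pairs when attention runs only inside each dilated segment."""
    segments = dilated_indices(seq_len, segment_size, dilation)
    return sum(len(seg) ** 2 for seg in segments)


# Two segments of 8 tokens, keeping every 2nd token in each.
print(dilated_indices(16, 8, 2))  # → [[0, 2, 4, 6], [8, 10, 12, 14]]

# Full attention over 8192 tokens versus segments of 128 with no dilation.
full = 8192 ** 2
sparse = attention_pairs(8192, 128, 1)
print(full, sparse, full // sparse)  # → 67108864 1048576 64
```

With a fixed segment size the pair count grows linearly in the sequence length, which is where the savings over full self-attention come from. The index lists could be used to gather rows of an input tensor before a standard attention call.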

## TODO:
* Check if I'm normalizing correctly
* Use for distillation

In [None]:
from src.dilated_attention import MultiHeadDilatedAttention

In [2]:
import numpy as np
import torch

seq_len = 1024 * 8
emb_size = 768


In [3]:
device = 'cuda:0'
x_in = torch.Tensor(np.random.randn(10, seq_len, emb_size)).to(device)  # batch of 10 random sequences
dilation_schedule = [1,1,2,4,8,16]
segment_sizes = [128,512,128,128,128]
mha = MultiHeadDilatedAttention(
    d_k = 256, dilation_schedule=dilation_schedule, segment_sizes = segment_sizes, device = device
)

---
### Benchmarking on an RTX 2080

In [6]:
@torch.compile
def _test_perf():
    return mha(x_in)

In [7]:
%%timeit
_test_perf()

1.82 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%%timeit
mha(x_in)

60.8 ms ± 823 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
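
CUDA kernels are launched asynchronously, so wall-clock timings like the ones above may partly reflect launch overhead rather than completed GPU work, and the first calls to a compiled function also include compilation time. A hedged way to cross-check the numbers is to warm up first and synchronize before reading the clock. The helper `time_call` below is a hypothetical sketch, not part of this repository; it takes the synchronization function as an argument so it stays independent of the backend.

```python
import time


def time_call(fn, sync=None, warmup=3, iters=10):
    """Return the mean seconds per call of `fn` over `iters` timed runs.

    `sync` is an optional zero-argument callable that blocks until pending
    device work has finished; it is called once before and once after the
    timed loop. Warm-up calls are excluded from the measurement.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work is done before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()  # wait for all timed work to finish before stopping the clock
    return (time.perf_counter() - start) / iters
```

In this notebook it could be called as `time_call(lambda: mha(x_in), sync=torch.cuda.synchronize)` and `time_call(_test_perf, sync=torch.cuda.synchronize)` to compare the eager and compiled paths under the same conditions.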
