## `PyTorch` implementation of the `LongNet` dilated attention
Dilated attention is a trick for greatly expanding the context legth of transformers.
It slices the input in segments and performs "dilated" selections for each multi-head attention block, 
in order to create a hierarchy of heads that attend to progressively more context while not incurring the quadradic memory 
and computational overhead of the self-attention layers.

## TODO:
* finish writing longnet transformer 
* make some first tests on lim. data
* Use for distillation

In [1]:

from src.dilated_attention import MultiHeadDilatedAttention, DilatedTransformerBlock

In [2]:
device = 'cuda:0'
nheads = 16
dilation_schedule = [1,  1, 2  ,2,  4]
segment_schedule =  [128,256,512,512,512]
model = DilatedTransformerBlock(
    segment_schedule = segment_schedule,
    dilation_schedule = dilation_schedule,
    device = device
)

In [3]:
import torch
class LongNetEncoder(torch.nn.Module):
    def __init__(
        self,
        n_layers = 8,
        dilation_schedule = dilation_schedule, 
        segment_schedule = segment_schedule, 
        emb_params = {'num_embeddings' : 1024, 'embedding_dim' : 768},
        device = None
    ):
        super(LongNetEncoder, self).__init__()
        self.dilation_schedule = dilation_schedule 
        
        self.segment_schedule = segment_schedule
        self.n_layers = 16
        self._is_built = False
        self.device = device 
        self.emb_params = emb_params
    
    def _build(self, x_in):
        self.blocks = torch.nn.ModuleList(
            [
                DilatedTransformerBlock(
                    dilation_schedule = self.dilation_schedule,
                    segment_schedule = self.segment_schedule,
                    device = self.device
                ) for i in range(self.n_layers)
            ]
        )
        self.embedding = torch.nn.Embedding(**self.emb_params, device= self.device)
        self._is_built = True
        
    def forward(self, x_in):
        """
        x_in is an iterable with integers.
        """
        if not self._is_built:
            self._build(x_in)
            
        x_curr = self.embedding(x_in)
        for m in self.blocks:
            x_curr = m(x_curr)
        
        return torch.nn.functional.softmax(x_curr, 1)
            



---
Instantiate a model and apply it to a long sequence

a 10-block, 32-head-per self-attention block transformer is applied to a 
sequence of length 4096 for embeddings of size 768

In [4]:
import numpy  as np
model = LongNetEncoder(device = device, n_layers = 8)

In [5]:
x_in_int = torch.from_numpy(1 * (np.random.randn(8,1024*4)>1)).to(device)
def _no_opt():
    model(x_in_int)

_no_opt()


---
### Benchmarking with an RTX2080.

In [6]:
%%timeit
_no_opt()

941 ms ± 2.28 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
@torch.compile
def _test_perf():
    model(x_in_int)

In [9]:
%%timeit
_test_perf()

1.31 ms ± 61.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
