In [1]:
!pip install local-attention



In [54]:
import torch
from local_attention import LocalAttention
from torch import Tensor, nn
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

In [3]:
batch =torch.rand(size=(256, 207, 361))

In [4]:

q = torch.randn(2, 8, 2048, 64)
k = torch.randn(2, 8, 2048, 64)
v = torch.randn(2, 8, 2048, 64)

attn = LocalAttention(
    dim = 64,                # dimension of each head (you need to pass this in for relative positional encoding)
    window_size = 512,       # window size. 512 is optimal, but 256 or 128 yields good enough results
    causal = True,           # auto-regressive or not
    look_backward = 1,       # each window looks at the window before
    look_forward = 0,        # for non-auto-regressive case, will default to 1, so each window looks at the window before and after it
    dropout = 0.1,           # post-attention dropout
    exact_windowsize = False # if this is set to true, in the causal setting, each query will see at maximum the number of keys equal to the window size
)

mask = torch.ones(2, 2048).bool()
out = attn(q, k, v, mask = mask) # (2, 8, 2048, 64)

In [5]:
from local_attention import LocalTransformer

In [6]:
model = LocalTransformer(
    num_tokens = 256,
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    causal = True,
    local_attn_window_size = 256
)

In [7]:
from transformers import LongformerConfig, LongformerModel

In [9]:
LongformerConfig

transformers.models.longformer.configuration_longformer.LongformerConfig

In [10]:
configuration = LongformerConfig()

In [11]:
configuration

LongformerConfig {
  "attention_probs_dropout_prob": 0.1,
  "attention_window": 512,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "longformer",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "onnx_export": false,
  "pad_token_id": 1,
  "sep_token_id": 2,
  "transformers_version": "4.36.2",
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [12]:
model = LongformerModel(configuration)

In [67]:
model.dtype

torch.float32

In [14]:
from transformers.models.longformer.modeling_longformer import LongformerLayer
from transformers.models.longformer.configuration_longformer import LongformerConfig

In [15]:
dict1 ={
  "attention_probs_dropout_prob": 0.1,
  "attention_window": [
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 768,
  "model_type": "longformer",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "onnx_export": False,
  "pad_token_id": 1,
  "sep_token_id": 2,
  "transformers_version": "4.36.2",
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [16]:
conf =LongformerConfig(attention_window = 36, hidden_size =270, intermediate_size = 1024,num_hidden_layers =4, num_attention_heads =8, max_position_embeddings=36)

In [17]:
long_layer =LongformerLayer(configuration)

In [65]:
long_layer.dtype

AttributeError: 'LongformerLayer' object has no attribute 'dtype'

In [18]:
batch =torch.rand(size=(256,1024,768))

In [19]:
batch.size()

torch.Size([256, 1024, 768])

In [46]:
attention_mask =torch.ones(batch.size()[:-1])

In [70]:
 def get_extended_attention_mask(
        attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None
    ) -> Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        """
        dtype = model.dtype
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and the dtype's smallest value for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
        return extended_attention_mask

In [77]:
0*torch.finfo(model.dtype).min

-0.0

In [71]:
extended_att =get_extended_attention_mask(attention_mask=attention_mask, input_shape=batch.size()[:-1])

In [75]:
extended_att[
            :, 0, 0, :
        ]> 0

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [47]:
is_index_masked = attention_mask < 0
is_index_global_attn = attention_mask < 0
is_global_attn = is_index_global_attn.flatten().any().item()
output_attentions=False,

In [48]:
is_index_global_attn.shape

torch.Size([256, 1024])

In [49]:
is_global_attn

False

In [50]:
layer_outputs = long_layer(
                    batch,
                    attention_mask=attention_mask,
                    layer_head_mask=None,
                    is_index_masked=is_index_masked,
                    is_index_global_attn=is_index_global_attn,
                    is_global_attn=is_global_attn,
                    output_attentions=output_attentions,
                )

In [None]:
len(layer_outputs)

In [118]:
layer_outputs[0].shape

torch.Size([256, 1024, 768])

In [85]:
layer_outputs[1].shape

torch.Size([256, 512, 12, 1025])

In [86]:
layer_outputs[2].shape

torch.Size([256, 12, 512, 512])

In [25]:
configuration

LongformerConfig {
  "attention_probs_dropout_prob": 0.1,
  "attention_window": [
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512,
    512
  ],
  "bos_token_id": 0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "longformer",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "onnx_export": false,
  "pad_token_id": 1,
  "sep_token_id": 2,
  "transformers_version": "4.36.2",
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [67]:
from transformers import RobertaTokenizer

In [68]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

In [69]:
SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000)  # long input document

input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0) 

Token indices sequence length is longer than the specified maximum sequence length for this model (4002 > 512). Running this sequence through the model will result in indexing errors


In [71]:
input_ids.shape

torch.Size([1, 4002])