In [3]:
import torch
import torch.nn as nn
import time
import math

In [4]:
def make_window_dilation_pairs(alpha, sequence_length):
    i = 1
    pairs = []
    while i*4 <= sequence_length:
        pairs.append((i*4, i)) # window_size, dilation_rate
        i *= alpha
    return pairs 

def create_dilated_mask(row_dim, col_dim, dilation_rate, head_index=0, offset=True):
    mask = torch.zeros(row_dim, col_dim)
    start = (head_index % dilation_rate) if offset else 0
    for i in range(start, row_dim, dilation_rate):
        for j in range(start, col_dim, dilation_rate):
            # if i >= j:
            mask[i, j] = 1
    return mask

def sparseToDense(sparse_tensor, dilation_rate, head_index=0, offset=True):
    leading_dims = sparse_tensor.shape[:-2]
    s_r, s_c = sparse_tensor.shape[-2], sparse_tensor.shape[-1]
    d_r, d_c = s_r // dilation_rate, s_c // dilation_rate
    dense_tensor = torch.zeros(*leading_dims, d_r, d_c, device=sparse_tensor.device)
    
    start = (head_index % dilation_rate) if offset else 0
    for i in range(d_r):
        for j in range(d_c):
            dense_tensor[..., i, j] = sparse_tensor[..., start + i * dilation_rate, start + j * dilation_rate]
    return dense_tensor

def denseToSparse(dense_tensor, dilation_rate, head_index=0, offset=True):
    leading_dims = dense_tensor.shape[:-2]
    d_r, d_c = dense_tensor.shape[-2], dense_tensor.shape[-1]
    s_r, s_c = d_r * dilation_rate, d_c * dilation_rate
    sparse_tensor = torch.zeros(*leading_dims, s_r, s_c, device=dense_tensor.device)
    
    start = (head_index % dilation_rate) if offset else 0
    for i in range(d_r):
        for j in range(d_c):
            sparse_tensor[..., start + i * dilation_rate, start + j * dilation_rate] = dense_tensor[..., i, j]
    return sparse_tensor

In [5]:
class MixedDilatedAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.alpha = config.alpha
        self.wr_pairs = [(4, 1), (8, 2), (16, 4), (32, 8)] # make_window_dilation_pairs(alpha=self.alpha, sequence_length=T)

    # attention within a window
    def dilated_attention_window(self, partial_q, partial_k, partial_v, window_size, dilation_rate, dropout_p=0.0, is_causal=False):
        head_index, window_size, hidden_dim = partial_q.size(-3), partial_q.size(-2), partial_k.size(-1)
        scale_factor = 1 / math.sqrt(hidden_dim)
        # attn_bias = torch.zeros(window_size, window_size, dtype=partial_q.dtype)
    
        # generate and apply masks to q, k, and v
        mask = create_dilated_mask(window_size, hidden_dim, dilation_rate, head_index, offset=True)
        masked_q = partial_q * mask
        masked_k = partial_k * mask
        masked_v = partial_v * mask
        
        attn_weight = torch.matmul(masked_q, masked_k.transpose(-2, -1)) * scale_factor
        
        # Apply causal mask if is_causal is True
        if is_causal:
            causal_mask = torch.tril(torch.ones(window_size, window_size, dtype=torch.bool))
            # attn_bias.masked_fill_(~causal_mask, float("-inf") )
            attn_weight = attn_weight.masked_fill_(~causal_mask, float("-inf") )
        
        
        # print(attn_weight)
        # attn_weight = sparseToDense(attn_weight, dilation_rate, head_index)
        attn_weight = torch.softmax(attn_weight, dim=-1)
        # attn_weight = denseToSparse(attn_weight, dilation_rate, head_index)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        # print(attn_weight[0][0])
        
        output_hat = attn_weight @ masked_v
        output_hat = output_hat * mask # output masking rule
        num_row = int(attn_weight.sum(dim=-1).sum().item()) # row that has some values other than zeros
        return output_hat, attn_weight, num_row

    def forward(self, x):
        start = time.time()
        B, T, C = x.size() # batch, seq_len, embedding dim (from nanogpt)
        head_dim = C // self.n_head
        qkv = self.c_attn(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q, k, v = [tensor.view(*tensor.shape[:-1], self.n_head, -1).transpose(-3, -2) for tensor in (q, k, v)]
        
        y = torch.zeros_like(x)
        denominator = []

        for window_size, dilation_rate in self.wr_pairs: # multiple segment - dilation pairs
            partial_denominator = 0
            num_windows = T // window_size
            concated_output = torch.zeros_like(x)
            
            # print(num_windows)
            for i in range(num_windows): # parallel segment
                start = i * window_size
                end = start + window_size
                
                # Slice out the window for q, k, v
                partial_q = q[:, :, start:end, :]  # (B, nh, window_size, hs)
                partial_k = k[:, :, start:end, :]  # (B, nh, window_size, hs)
                partial_v = v[:, :, start:end, :]  # (B, nh, window_size, hs)
                window_output, attn_weight, num_row = self.dilated_attention_window(
                    partial_q, partial_k, partial_v, window_size, dilation_rate, is_causal=True
                )

                # Reshape window_output to (B, window_size, C) for placement in concated_output
                window_output = window_output.transpose(1, 2).reshape(B, window_size, C)
                concated_output[:, start:end, :] = window_output
                partial_denominator += num_row
            
            denominator.append(partial_denominator)
            y += concated_output * partial_denominator
  
        y /= sum(denominator)
        
        att_weights, updated_kv_cache = None, None 
        end = time.time()
        print(f"Attention time: {1000*(end - start):.4f} ms")
        
        return y, att_weights, updated_kv_cache
    
class Config:
    # block_size: int = 16 # max seq_len
    n_embd = 4
    n_head = 1
    alpha = 2

config = Config()
sequence_length = 32
hidden_dim = config.n_embd

x = torch.randn(1, sequence_length, hidden_dim)  # Batch size of 1
attention_layer = MixedDilatedAttention(config)
output = attention_layer(x)
# print(output)

Attention time: 1732874600559.0066 ms


In [9]:
# https://github.com/kyegomez/LongNet/blob/master/long_net/attention.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class RelativePositionBias(nn.Module):
    def __init__(
        self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=12
    ):
        super().__init__()
        self.bidirectional = bidirectional
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.n_heads = n_heads
        self.relative_attention_bias = nn.Embedding(
            self.num_buckets, self.n_heads
        )

    @staticmethod
    def _relative_position_bucket(
        relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        ret = 0
        n = -relative_position
        if bidirectional:
            num_buckets //= 2
            ret += (n < 0).to(torch.long) * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        val_if_large = torch.min(
            val_if_large, torch.full_like(val_if_large, num_buckets - 1)
        )

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def compute_bias(self, qlen, klen, step=None):
        step = 0 if step is None else step
        context_position = torch.arange(
            step,
            step + qlen,
            dtype=torch.long,
            device=self.relative_attention_bias.weight.device,
        )[:, None]
        memory_position = torch.arange(
            klen,
            dtype=torch.long,
            device=self.relative_attention_bias.weight.device,
        )[None, :]
        relative_position = (
            memory_position - context_position
        )  # shape (qlen, klen)

        rp_bucket = self._relative_position_bucket(
            relative_position,  # shape (qlen, klen)
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(
            rp_bucket
        )  # shape (qlen, klen, heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0
        )  # shape (1, heads, qlen, klen)
        return values

    def forward(self, batch_size, qlen, klen, step=None):
        # shape (batch * heads, qlen, klen)
        return (
            self.compute_bias(qlen, klen, step)
            .repeat(batch_size, 1, 1, 1)
            .view(-1, qlen, klen)
        )



def fixed_pos_embedding(x):
    seq_len, dim = x.shape
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim))
    sinusoid_inp = torch.einsum(
        "i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq
    ).to(x)
    return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)


def rotate_every_two(x):
    x1 = x[:, :, ::2]
    x2 = x[:, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(
        -2
    )  # in einsum notation: rearrange(x, '... d j -> ... (d j)')\


def duplicate_interleave(m):
    """
    A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
    """
    dim0 = m.shape[0]
    m = m.view(-1, 1)  # flatten the matrix
    m = m.repeat(1, 2)  # repeat all elements into the 2nd dimension
    m = m.view(dim0, -1)  # reshape into a matrix, interleaving the copy
    return m


def apply_rotary_pos_emb(x, sin, cos, scale=1):
    sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos))
    # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
    return (x * cos) + (rotate_every_two(x) * sin)


class XPOS(nn.Module):
    def __init__(self, head_dim, scale_base=512):
        super().__init__()
        self.head_dim = head_dim
        self.scale_base = scale_base
        self.register_buffer(
            "scale",
            (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim),
        )

    def forward(self, x, offset=0, downscale=False):
        length = x.shape[1]
        min_pos = -(length + offset) // 2
        max_pos = length + offset + min_pos
        scale = (
            self.scale
            ** torch.arange(min_pos, max_pos, 1)
            .to(self.scale)
            .div(self.scale_base)[:, None]
        )
        sin, cos = fixed_pos_embedding(scale)

        if scale.shape[0] > length:
            scale = scale[-length:]
            sin = sin[-length:]
            cos = cos[-length:]

        if downscale:
            scale = 1 / scale

        x = apply_rotary_pos_emb(x, sin, cos, scale)
        return x



# add alibi, qk layer norm, one write head, multihway,
class DilatedAttention2(nn.Module):
    """
    Dilated Attention Module.

    Arguments:
        dim: The dimension of the attention layers.
        heads: The number of attention heads.
        dilation_rate: The dilation rate for dilated attention.
        segment_size: The segment size for dilated attention.
        dropout (optional): The dropout probability. Default: 0.0
        causal (optional): If set to True, the attention mechanism is causal. Default: False
        use_xpos (optional): If set to True, xpos is used for positional encoding. Default: False
        use_rel_pos_bias (optional): If set to True, relative position bias is used in the attention mechanism. Default: False

    Usage:
        The `DilatedAttention` class can be used as a module for neural networks and is especially suited for transformer architectures.

        Example:
            attention = DilatedAttention(dim=512, heads=8, dilation_rate=2, segment_size=64, use_xpos=True, use_rel_pos_bias=True)
            output = attention(input_tensor)

        This will return the output tensor after applying dilated attention. The `use_xpos` and `use_rel_pos_bias` parameters allow for switching on positional encoding and relative positional bias respectively.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dilation_rate: int,
        segment_size: int,
        # dropout: float = 0.0,
        causal: bool = False,
        use_xpos: bool = False,
        use_rel_pos_bias: bool = False,
        qk_norm: bool = False,
        dtype: torch.dtype = torch.float16,
        device: str = "cuda:0",
    ) -> None:
        super(DilatedAttention2, self).__init__()
        self.dim = dim
        self.heads = heads
        self.dilation_rate = dilation_rate
        self.segment_size = segment_size
        # self.dropout = nn.Dropout(dropout)
        self.causal = causal
        self.use_xpos = use_xpos
        self.use_rel_pos_bias = use_rel_pos_bias
        self.qk_norm = qk_norm
        self.dtype = dtype
        self.device = device

        # self.attention = FlashAttention(causal=self.causal, dropout=dropout).to(device)

        # if use_xpos:
        #     self.xpos = XPOS(head_dim=dim // heads)
        # if use_rel_pos_bias:
        #     self.relative_bias = RelativePositionBias(
        #         num_buckets=32, max_distance=128, n_heads=heads
        #     )

        self.norm = nn.LayerNorm(dim)

        # head offsets
        self.head_offsets = nn.Parameter(torch.randn(heads, dim))

        # Linear Projections
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)

    def get_mask(self, i, j):
        """i = row, j=column"""
        return torch.ones((i, j), device=self.device, dtype=torch.bool).triu(
            j - i + 2
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the DilatedAttention module.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor.
        """
        start = time.time()
        
        batch_size, seq_len, _ = x.shape
        padding_len = -seq_len % self.segment_size
        x = F.pad(x, (0, 0, 0, padding_len))
        seq_len = seq_len + padding_len

        # if self.use_xpos:
        #     x = self.xpos(x)
        print('Before split and sparsify', x.shape)

        # Split and sparsify
        x = x.view(batch_size, -1, self.segment_size, self.dim)
        x = x[:, :, :: self.dilation_rate, :]
        
        print('After split and sparsify', x.shape)

        # qk_norm
        if self.qk_norm:
            q, k, v = map(
                self.norm, (self.proj_q(x), self.proj_k(x), self.proj_v(x))
            )
        else:
            q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)

        # Perform attention
        # attn_output = self.attention(q, k, v)
        attn_output = F.scaled_dot_product_attention(q, k, v)

        # if use rel pos => apply relative positioning bias
        # if self.use_rel_pos_bias:
        #     attn_output += self.relative_bias(
        #         batch_size, attn_output.size(1), attn_output.size(1)
        #     )

        # if causal create a mask and apply to the output
        if self.causal:
            mask = self.get_mask(attn_output.size(1), attn_output.size(1))

            attn_output = attn_output.masked_fill(mask, float("-inf"))

        # apply dropout
        # attn_output = self.dropout(attn_output)
        # Scatter and concatenate
        attn_output = attn_output.reshape(batch_size, -1, self.dim)

        end = time.time()
        print(f"Attention time: {1000*(end - start):.4f} ms")

        return attn_output

sequence_length = 32768
hidden_dim = 768
x = torch.randn(1, sequence_length, hidden_dim)  # Batch size of 1
attention_layer = DilatedAttention2(
	dim=768, heads=12, dilation_rate=6, segment_size=16384, use_xpos=False, use_rel_pos_bias=False, qk_norm=False
)
output = attention_layer(x)
# print(x)

Before split and sparsify torch.Size([1, 32768, 768])
After split and sparsify torch.Size([1, 2, 2731, 768])
Attention time: 95.6528 ms
