In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math

from causal_self_attention import CausalSelfAttention

# Run on the GPU when CUDA is usable; otherwise everything stays on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [55]:
class DilatedAttention2(nn.Module):
    """Causal attention over a dilated (strided) subset of each fixed-size segment.

    The input sequence is cut into consecutive segments of ``segment_size``
    tokens. Inside every segment only every ``dilation_rate``-th token is kept,
    causal attention is run independently per segment on the kept tokens, and
    the results are scattered back to their original positions. Positions that
    were skipped by the dilation are zero in the output (``out_proj`` has no
    bias, so zeros stay zeros).

    Attention is computed over the full embedding width; ``n_head`` and
    ``head_dim`` are stored but not used by ``forward``.

    Args:
        config: object exposing ``n_embd``, ``n_head`` and ``block_size``.
        segment_size: number of tokens per segment; the sequence length passed
            to ``forward`` must be a multiple of it.
        dilation_rate: stride used when selecting tokens inside a segment;
            must divide ``segment_size``.
    """

    def __init__(self, config, segment_size, dilation_rate):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # NOTE(review): c_attn / c_proj are never used by forward(). They are
        # kept so the module's parameter set (state_dict keys) does not change.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.block_size = config.block_size

        self.segment_size = segment_size
        self.dilation_rate = dilation_rate

        # Resolve the placement locally (CUDA when available, else CPU) so the
        # class does not depend on a notebook-level global being defined.
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Linear projections
        self.proj_q = nn.Linear(config.n_embd, config.n_embd, bias=False).to(target_device)
        self.proj_k = nn.Linear(config.n_embd, config.n_embd, bias=False).to(target_device)
        self.proj_v = nn.Linear(config.n_embd, config.n_embd, bias=False).to(target_device)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=False).to(target_device)

        self.norm = nn.LayerNorm(config.n_embd).to(target_device)

    def forward(self, x, kv_cache=None, use_cache=False, output_attentions=False):
        """Apply dilated segment attention to ``x``.

        Args:
            x: tensor of shape ``(B, N, D)`` where ``N`` is a multiple of
                ``segment_size``.
            kv_cache: optional dict with ``"k"`` and ``"v"`` tensors that are
                prepended to the new keys/values along the token axis. Only
                read when ``use_cache`` is true.
            use_cache: when true, the keys/values used for attention are
                returned as the third element of the result.
            output_attentions: accepted for interface compatibility; attention
                weights are never returned.

        Returns:
            ``(y, None, cache)`` where ``y`` has shape ``(B, N, D)`` and
            ``cache`` is ``{"k": ..., "v": ...}`` when ``use_cache`` is true,
            otherwise ``None``.

        Raises:
            AssertionError: if ``N`` is not a multiple of ``segment_size`` or
                ``segment_size`` is not a multiple of ``dilation_rate``.
        """
        B, N, D = x.size()

        assert N % self.segment_size == 0, f"N: {N}, segment_size: {self.segment_size}"
        assert self.segment_size % self.dilation_rate == 0

        # Sparsify: split into segments, then keep every dilation_rate-th token.
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(B, N // self.segment_size, self.segment_size, D)
        x = x[:, :, :: self.dilation_rate, :]
        # q, k, v: (B, num_segments, segment_size // dilation_rate, D)
        q, k, v = map(self.norm, (self.proj_q(x), self.proj_k(x), self.proj_v(x)))

        # TODO: cache handling is still experimental.
        if use_cache and kv_cache is not None:
            k = torch.cat([kv_cache["k"], k], dim=-2)  # Append new keys
            v = torch.cat([kv_cache["v"], v], dim=-2)  # Append new values

        # Hand the keys/values back to the caller when caching is enabled.
        updated_kv_cache = {"k": k, "v": v} if use_cache else None

        # TODO: Implement shifting positions

        q_len, kv_len = q.size(-2), k.size(-2)
        if kv_len == q_len:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # With cached keys the score matrix is not square. is_causal=True
            # would align the triangle to the top-left corner and hide most of
            # the cache, so build a bottom-right aligned causal mask instead:
            # every query sees all cached keys plus the new keys up to itself.
            causal_mask = torch.ones(
                q_len, kv_len, dtype=torch.bool, device=q.device
            ).tril(diagonal=kv_len - q_len)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

        # Scatter the sparse result back to the original token positions.
        y = y.reshape(B, -1, D)  # (B, N // dilation_rate, D)
        y_full = torch.zeros(B, N, D, device=y.device, dtype=y.dtype)
        y_full[:, ::self.dilation_rate, :] = y
        y_full = self.out_proj(y_full)

        att_weights = None  # attention weights are not exposed by this layer

        return y_full, att_weights, updated_kv_cache

class Config:
    """Tiny hyper-parameter bundle used to smoke-test DilatedAttention2."""

    block_size = 32
    n_embd = 8
    n_head = 4


# Smoke test: one sequence of 32 tokens, i.e. four segments of eight tokens,
# with every second token of each segment taking part in attention.
config = Config()
hidden_dim = config.n_embd
sequence_length = 32

x = torch.randn(1, sequence_length, hidden_dim)  # batch size of 1
x = x.to(device)
attention_layer = DilatedAttention2(config=config, segment_size=8, dilation_rate=2)
output = attention_layer(x)

# output[0]

In [56]:
class MixedDilatedAttention2(nn.Module):
    """Sum of several dilated-attention branches with a dense fallback.

    One ``DilatedAttention2`` branch is created per ``(segment_size,
    dilation_rate)`` pair. When the sequence length is divisible by every
    segment size, all branches run on the same input and their outputs are
    added together. Otherwise a single ``CausalSelfAttention`` block is used.

    Args:
        config: configuration object forwarded to every attention block.
        wr_pairs: sequence of ``(segment_size, dilation_rate)`` tuples.
    """

    def __init__(self, config, wr_pairs):
        super().__init__()
        self.config = config
        self.wr_pairs = wr_pairs
        # Build each dilated branch exactly once.
        self.dilated_attn = nn.ModuleList(
            [
                DilatedAttention2(self.config, segment_size, dilation_rate)
                for segment_size, dilation_rate in self.wr_pairs
            ]
        )
        self.self_attn = nn.ModuleList([CausalSelfAttention(config)])

    def forward(self, x, kv_cache=None, use_cache=False, output_attentions=False):
        """Run the dilated branches (or the dense fallback) on ``x``.

        Args:
            x: input tensor; its second dimension is the sequence length.
            kv_cache, use_cache, output_attentions: accepted for interface
                compatibility; they are not forwarded to the blocks.

        Returns:
            ``(output, None, None)`` where ``output`` is the element-wise sum
            of the outputs of every block that ran.
        """
        N = x.size(1)

        # The dilated path needs at least one pair, and every segment size
        # must divide the sequence length.
        is_dilated = len(self.wr_pairs) > 0 and all(
            N % segment_size == 0 for segment_size, _ in self.wr_pairs
        )

        if is_dilated:
            blocks, label = self.dilated_attn, "dilated_attention"
        else:
            blocks, label = self.self_attn, "self_attn"

        output = None
        for block in blocks:
            # Progress print kept so the notebook output still shows which
            # path was taken.
            print(label)
            block_output, _, _ = block(x)
            # Accumulate inside the loop so every branch contributes, not
            # only the last one.
            output = block_output if output is None else output + block_output

        att_weights, updated_kv_cache = None, None
        return output, att_weights, updated_kv_cache
class Config:
    """Full-size hyper-parameters for the MixedDilatedAttention2 run.

    Re-binds the name ``Config``; the small smoke-test configuration defined
    earlier in the notebook is shadowed from here on.
    """

    block_size = 32768
    n_embd = 768
    n_head = 12


config = Config()
hidden_dim = config.n_embd
# 32760 is not a multiple of the smallest segment size (2048), so this run
# takes the dense self-attention fallback rather than the dilated branches.
sequence_length = 32760

# (segment_size, dilation_rate) pairs: both double at every step,
# from (2048, 1) up to (32768, 16).
wr_pairs = [(2048 * 2 ** step, 2 ** step) for step in range(5)]

x = torch.randn(2, sequence_length, hidden_dim)  # batch size of 2
x = x.to(device)
attention_layer = MixedDilatedAttention2(config, wr_pairs)
output = attention_layer(x)

output[0].shape


self_attn


torch.Size([2, 32760, 768])