In [5]:
import torch 

In [6]:
def create_dilated_mask(row_dim, col_dim, dilation_rate, head_index=0, offset=True):
    """Build a 2-D 0/1 mask that keeps every ``dilation_rate``-th row and column.

    Entry ``(i, j)`` is 1 when both ``i`` and ``j`` lie on the grid
    ``start, start + dilation_rate, start + 2 * dilation_rate, ...`` and 0
    everywhere else.

    Args:
        row_dim: Number of rows of the mask.
        col_dim: Number of columns of the mask.
        dilation_rate: Positive stride between kept positions.
        head_index: Index used to shift the grid; only its remainder modulo
            ``dilation_rate`` matters.
        offset: When True the grid starts at ``head_index % dilation_rate``,
            otherwise it starts at 0.

    Returns:
        A ``(row_dim, col_dim)`` tensor of zeros and ones in torch's default dtype.

    Raises:
        ValueError: If ``dilation_rate`` is smaller than 1.
    """
    if dilation_rate < 1:
        raise ValueError(f"dilation_rate must be >= 1, got {dilation_rate}")

    mask = torch.zeros(row_dim, col_dim)
    start = (head_index % dilation_rate) if offset else 0
    # One strided slice assignment marks the whole grid; this replaces a Python
    # double loop that wrote the same entries one element at a time.
    mask[start::dilation_rate, start::dilation_rate] = 1
    return mask

def sparseToDense(sparse_tensor, dilation_rate, head_index=0, offset=True):
    """Gather the dilated-grid entries of a matrix into a compact matrix.

    "Sparse" and "dense" describe the dilation pattern only; both tensors use
    the ordinary strided layout. Output entry ``(i, j)`` is the input entry at
    ``(start + i * dilation_rate, start + j * dilation_rate)``.

    Args:
        sparse_tensor: Tensor whose last two dimensions are (rows, cols).
            Any leading dimensions are carried through unchanged.
        dilation_rate: Positive stride between gathered positions.
        head_index: Index used to shift the grid; only its remainder modulo
            ``dilation_rate`` matters.
        offset: When True the grid starts at ``head_index % dilation_rate``,
            otherwise it starts at 0.

    Returns:
        A new tensor of shape ``(..., rows // dilation_rate, cols // dilation_rate)``
        with the dtype and device of ``sparse_tensor``. It does not share
        storage with the input.

    Raises:
        ValueError: If ``dilation_rate`` is smaller than 1.
    """
    if dilation_rate < 1:
        raise ValueError(f"dilation_rate must be >= 1, got {dilation_rate}")

    s_r, s_c = sparse_tensor.shape[-2:]
    d_r, d_c = s_r // dilation_rate, s_c // dilation_rate
    start = (head_index % dilation_rate) if offset else 0

    # The strided slice can yield one row/column more than the floor division
    # when the size is not a multiple of dilation_rate, so trim to (d_r, d_c)
    # to keep the documented output shape.
    gathered = sparse_tensor[..., start::dilation_rate, start::dilation_rate]
    # clone() so the caller gets an independent tensor rather than a view.
    return gathered[..., :d_r, :d_c].clone()

def denseToSparse(dense_tensor, dilation_rate, head_index=0, offset=True):
    """Scatter a compact matrix back onto a dilated grid, zero-filling the rest.

    "Sparse" and "dense" describe the dilation pattern only; both tensors use
    the ordinary strided layout. Input entry ``(i, j)`` is written to output
    position ``(start + i * dilation_rate, start + j * dilation_rate)``.

    Args:
        dense_tensor: Tensor whose last two dimensions are (rows, cols).
            Any leading dimensions are carried through unchanged.
        dilation_rate: Positive stride between written positions.
        head_index: Index used to shift the grid; only its remainder modulo
            ``dilation_rate`` matters.
        offset: When True the grid starts at ``head_index % dilation_rate``,
            otherwise it starts at 0.

    Returns:
        A new tensor of shape ``(..., rows * dilation_rate, cols * dilation_rate)``
        with the dtype and device of ``dense_tensor``.

    Raises:
        ValueError: If ``dilation_rate`` is smaller than 1.
    """
    if dilation_rate < 1:
        raise ValueError(f"dilation_rate must be >= 1, got {dilation_rate}")

    d_r, d_c = dense_tensor.shape[-2:]
    s_r, s_c = d_r * dilation_rate, d_c * dilation_rate
    start = (head_index % dilation_rate) if offset else 0

    # new_zeros keeps the input's dtype and device instead of falling back to
    # a default-dtype CPU tensor.
    sparse_tensor = dense_tensor.new_zeros((*dense_tensor.shape[:-2], s_r, s_c))
    # Because start < dilation_rate, the strided slice has exactly
    # (d_r, d_c) positions, so the assignment shapes match.
    sparse_tensor[..., start::dilation_rate, start::dilation_rate] = dense_tensor
    return sparse_tensor

# def create_dilated_mask(row_dim, col_dim, dilation_rate, head_index=0, offset=True): # paper-based
#     mask = torch.zeros(row_dim, col_dim)
#     start = (head_index % dilation_rate) if offset else 0
#     for i in range(start, row_dim, dilation_rate):
#         mask[i, :] = 1  # Select every `dilation_rate`-th row
#     return mask

In [15]:
import math
from torch.nn import functional as F

# Example settings for the dilated-attention experiments in this notebook.
# NOTE(review): these are module-level globals. Functions that take parameters
# with the same names shadow them, so a value changed here only takes effect
# where a cell reads the global directly — confirm which cells do.
dilation_rate = 2  # stride between positions kept by the dilated mask
window_size = 8  # presumably the number of positions per attention window — confirm
hidden_dim = 16  # presumably the feature width of q/k/v — confirm
offset = True  # mirrors the `offset` flag of the mask helpers (shift grid by head index)
head_index = 0  # head index from which the grid shift is derived

# attention within a window
def dilated_attention_window(partial_q, partial_k, partial_v, window_size, dilation_rate, head_index=0, dropout_p=0.0, is_causal=False):
    """Dilated scaled-dot-product attention inside a single window.

    q, k and v are multiplied element-wise by the 2-D dilated mask, so only
    entries on the dilation grid contribute. The score matrix is compacted to
    its grid entries, soft-maxed, scattered back to full size, and applied to
    the masked values.

    Args:
        partial_q: Query slice for the window; indexed as (positions, features).
        partial_k: Key slice with the same shape as ``partial_q``.
        partial_v: Value slice with the same shape as ``partial_q``.
        window_size: Kept for interface compatibility; the effective window
            size is read from ``partial_q.size(-2)``.
        dilation_rate: Stride of the dilation grid.
        head_index: Head index; shifts the grid by ``head_index % dilation_rate``.
        dropout_p: Dropout probability applied to the attention weights.
        is_causal: When True, a position may only attend to itself and to
            earlier positions.

    Returns:
        ``(output_hat, num_row)``: the masked attention output, and the number
        of rows of the attention matrix that contain at least one non-zero
        weight.
    """
    # The tensor shapes are authoritative; this overwrites the window_size argument.
    window_size, hidden_dim = partial_q.size(-2), partial_k.size(-1)
    # Fix: the original referenced an undefined name `d` here (NameError on every call).
    scale_factor = 1 / math.sqrt(hidden_dim)
    attn_bias = torch.zeros(window_size, window_size, dtype=partial_q.dtype)

    # generate and apply masks to q, k, and v
    mask = create_dilated_mask(window_size, hidden_dim, dilation_rate, head_index, offset=True)
    masked_q = partial_q * mask
    masked_k = partial_k * mask
    masked_v = partial_v * mask

    # Apply causal mask if is_causal is True
    if is_causal:
        causal_mask = torch.tril(torch.ones(window_size, window_size, dtype=torch.bool))
        attn_bias.masked_fill_(~causal_mask, float("-inf"))

    attn_weight = torch.matmul(masked_q, masked_k.transpose(-2, -1)) * scale_factor + attn_bias
    # Softmax only over the grid entries: compact, normalise, scatter back.
    attn_weight = sparseToDense(attn_weight, dilation_rate, head_index)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = denseToSparse(attn_weight, dilation_rate, head_index)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    # Guard against the compact/scatter round trip returning a different dtype
    # than the values; matmul does not promote mixed dtypes.
    attn_weight = attn_weight.to(masked_v.dtype)

    output_hat = attn_weight @ masked_v
    output_hat = output_hat * mask  # output masking rule
    # Count the rows that carry any weight. Summing the weights and truncating
    # with int() can under-count by one when the float sum lands just below an
    # integer (e.g. 3.9999998 -> 3).
    num_row = int((attn_weight != 0).any(dim=-1).sum().item())
    return output_hat, num_row


In [30]:
def make_window_dilation_pairs(sequence_length=32, alpha=2):
    """List the (window_size, dilation_rate) pairs that fit in a sequence.

    The dilation rate starts at 1 and is multiplied by ``alpha`` each step; the
    matching window size is always four times the dilation rate. Pairs are
    produced while the window size does not exceed ``sequence_length``.

    Args:
        sequence_length: Upper bound for the window size.
        alpha: Growth factor between consecutive pairs; must be greater than 1.

    Returns:
        A list of ``(window_size, dilation_rate)`` tuples in increasing order.
        The list is empty when ``sequence_length`` is smaller than 4.

    Raises:
        ValueError: If ``alpha`` is not greater than 1. Such a factor never
            grows the window, so the loop would not terminate.
    """
    if alpha <= 1:
        raise ValueError(f"alpha must be greater than 1, got {alpha}")

    pairs = []
    rate = 1
    while rate * 4 <= sequence_length:
        pairs.append((rate * 4, rate))  # window_size, dilation_rate
        rate *= alpha
    return pairs

class MixedDilatedAttention(torch.nn.Module):
    """Multi-head self-attention mixing several (window size, dilation rate) patterns.

    For every pair returned by ``make_window_dilation_pairs`` the sequence is
    cut into consecutive windows, and each head attends inside each window via
    ``dilated_attention_window`` using its own head index as the grid shift.
    The per-pattern outputs are averaged with weights proportional to the row
    counts reported by ``dilated_attention_window``.

    Args:
        config: Object exposing ``n_embd`` (model width) and ``n_head`` (number
            of heads); ``n_embd`` must be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # `torch.nn` is spelled out in full because this notebook imports only
        # `torch`; the bare name `nn` is not in scope.
        self.c_attn = torch.nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        """Apply mixed dilated attention.

        Args:
            x: Input of shape ``(batch, sequence_length, n_embd)``.

        Returns:
            Tensor of shape ``(batch, sequence_length, n_embd)`` after the
            output projection.

        Raises:
            ValueError: If the sequence is too short for any window pattern.
        """
        batch_size, sequence_length, hidden_dim = x.size()
        head_dim = hidden_dim // self.n_head

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        head_shape = (batch_size, sequence_length, self.n_head, head_dim)
        q = q.view(head_shape).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(head_shape).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(head_shape).transpose(1, 2)  # (B, nh, T, hs)

        # Build the patterns for the actual sequence length, not the default.
        wr_pairs = make_window_dilation_pairs(sequence_length)
        if not wr_pairs:
            raise ValueError(
                f"sequence_length={sequence_length} is too short for any window pattern"
            )

        output = torch.zeros_like(q)
        total_denominator = 0

        for window_size, dilation_rate in wr_pairs:  # multiple segment - dilation pairs
            partial_denominator = 0
            concated_output = torch.zeros_like(q)

            # Positions after the last full window keep a zero contribution
            # for this pattern.
            for window_index in range(sequence_length // window_size):
                start = window_index * window_size
                end = start + window_size

                for b in range(batch_size):
                    for h in range(self.n_head):
                        window_output, num_row = dilated_attention_window(
                            q[b, h, start:end],
                            k[b, h, start:end],
                            v[b, h, start:end],
                            window_size,
                            dilation_rate,
                            head_index=h,
                            is_causal=True,
                        )
                        concated_output[b, h, start:end] = window_output
                        partial_denominator += num_row

            # Weight this pattern by its own row count. Counting over every
            # batch element and head scales all patterns alike, so the
            # normalised weights are unaffected.
            output = output + concated_output * partial_denominator
            total_denominator += partial_denominator

        output = output / total_denominator

        # Re-assemble the heads side by side and project back to model width.
        output = output.transpose(1, 2).contiguous().view(batch_size, sequence_length, hidden_dim)
        return self.c_proj(output)

# NOTE(review): `sequence_length` and `dilated_attention` are not defined in any
# cell visible in this notebook, so this cell presumably relies on state left in
# the kernel by an earlier or deleted cell — confirm, or define them above so the
# notebook survives Restart & Run All.
x = torch.randn(sequence_length, hidden_dim)
output = dilated_attention(x)
print(output)

tensor([[-1.5223e-01, -6.3395e-01, -6.7170e-01,  8.3738e-01,  3.9974e-01,
          2.1473e-01,  9.1315e-01,  2.0024e-02,  1.1435e+00, -6.5006e-02,
          2.2309e-01,  1.8378e-01,  3.5100e-01,  1.2653e-01, -1.6455e-01,
         -3.6550e-01],
        [-2.6500e-01, -2.4508e-01, -3.3913e-01,  5.7264e-01, -1.2404e-01,
          1.8270e-01,  4.1402e-01, -3.1299e-02,  6.5652e-01, -5.9528e-01,
          2.7515e-01, -2.8199e-01,  1.1703e-01,  1.3926e-01, -2.6042e-01,
         -4.4480e-02],
        [-9.3077e-02,  3.4226e-01,  2.7646e-01,  1.7710e-01, -8.6060e-02,
          8.2148e-02, -3.7113e-01, -1.8295e-01,  9.1809e-01, -9.4049e-01,
          1.1124e-03, -3.8898e-01, -9.6351e-02, -2.7345e-02, -5.5050e-01,
          5.4099e-01],
        [-2.3615e-01, -2.0461e-01, -2.8226e-01,  5.7399e-01, -8.3917e-02,
          1.9291e-01,  3.6596e-01, -1.0659e-02,  5.4944e-01, -5.6472e-01,
          2.3974e-01, -2.5444e-01,  7.1538e-02,  1.1438e-01, -1.9084e-01,
         -4.7025e-02],
        [-4.6351e-01

In [None]:


# Combine the per-(window, dilation) outputs into one weighted average.
# NOTE(review): `sequence_length`, `wr_pairs` and `construct_full_attention_matrix`
# are not defined in any cell visible in this notebook — confirm they are
# created earlier, otherwise this cell fails on a fresh kernel.
output = torch.zeros(sequence_length, hidden_dim)
denominator = 0
for w, r in wr_pairs:
    concated_output, partial_denominator = construct_full_attention_matrix(sequence_length, window_size=w, dilation_rate=r)
    # Weight each pattern by its OWN denominator. The running total starts at
    # 0, so using it as the weight drops the first pattern entirely and
    # mis-weights every later one.
    output += concated_output * partial_denominator
    denominator += partial_denominator
# Dividing by the sum of the weights turns the accumulation into a weighted mean.
output /= denominator

print(output.size())

NameError: name 'torch' is not defined