In [1]:
import torch 

In [147]:
def create_dilated_mask(row_dim, col_dim, dilation_rate, head_index=0, offset=True):
    """Build a 0/1 mask that keeps every ``dilation_rate``-th row and column.

    Entry ``[i, j]`` is 1 when both ``i`` and ``j`` belong to
    ``start, start + dilation_rate, start + 2 * dilation_rate, ...`` and 0
    otherwise. No causal (``i >= j``) restriction is applied here.

    Args:
        row_dim: Number of rows of the mask.
        col_dim: Number of columns of the mask.
        dilation_rate: Positive step between kept rows/columns.
        head_index: Used to derive the starting position,
            ``head_index % dilation_rate``, when ``offset`` is true.
        offset: When false the selection always starts at index 0.

    Returns:
        A ``(row_dim, col_dim)`` tensor of zeros and ones in torch's default
        dtype.

    Raises:
        ValueError: If ``dilation_rate`` is smaller than 1.
    """
    if dilation_rate < 1:
        raise ValueError(
            f"dilation_rate must be a positive integer, got {dilation_rate}"
        )
    start = (head_index % dilation_rate) if offset else 0
    mask = torch.zeros(row_dim, col_dim)
    # One strided slice assignment replaces the per-element Python double loop.
    mask[start::dilation_rate, start::dilation_rate] = 1
    return mask

def sparseToDense(sparse_tensor, dilation_rate, head_index=0, offset=True):
    """Gather the dilated entries of a 2-D tensor into a compact tensor.

    Picks the elements at rows/columns ``start + k * dilation_rate`` and packs
    them into a ``(rows // dilation_rate, cols // dilation_rate)`` tensor.

    Args:
        sparse_tensor: 2-D tensor to read from.
        dilation_rate: Step between the picked rows/columns.
        head_index: Used to derive ``start = head_index % dilation_rate`` when
            ``offset`` is true.
        offset: When false the selection starts at index 0.

    Returns:
        A new tensor (not a view) with the same dtype and device as
        ``sparse_tensor``.
    """
    s_r, s_c = sparse_tensor.size()
    d_r, d_c = s_r // dilation_rate, s_c // dilation_rate
    start = (head_index % dilation_rate) if offset else 0
    # Strided slicing keeps the input's dtype/device; a fresh torch.zeros()
    # buffer would silently fall back to default float32.
    picked = sparse_tensor[start::dilation_rate, start::dilation_rate]
    # Trim to the floor-divided shape so sizes that are not a multiple of
    # dilation_rate still produce (rows // rate, cols // rate).
    return picked[:d_r, :d_c].clone()

def denseToSparse(dense_tensor, dilation_rate, head_index=0, offset=True):
    """Scatter a compact 2-D tensor onto a dilated grid of a larger tensor.

    Element ``[i, j]`` of ``dense_tensor`` is written to position
    ``[start + i * dilation_rate, start + j * dilation_rate]`` of a zero-filled
    ``(rows * dilation_rate, cols * dilation_rate)`` tensor.

    Args:
        dense_tensor: 2-D tensor holding the compact values.
        dilation_rate: Step between the written rows/columns.
        head_index: Used to derive ``start = head_index % dilation_rate`` when
            ``offset`` is true.
        offset: When false the grid starts at index 0.

    Returns:
        A new tensor with the same dtype and device as ``dense_tensor``.
    """
    d_r, d_c = dense_tensor.size()
    s_r, s_c = d_r * dilation_rate, d_c * dilation_rate
    start = (head_index % dilation_rate) if offset else 0
    # new_zeros inherits dtype/device from the input; torch.zeros() would not.
    sparse_tensor = dense_tensor.new_zeros(s_r, s_c)
    # The strided view has exactly (d_r, d_c) entries because start < dilation_rate.
    sparse_tensor[start::dilation_rate, start::dilation_rate] = dense_tensor
    return sparse_tensor

# def create_dilated_mask(row_dim, col_dim, dilation_rate, head_index=0, offset=True): # paper-based
#     mask = torch.zeros(row_dim, col_dim)
#     start = (head_index % dilation_rate) if offset else 0
#     for i in range(start, row_dim, dilation_rate):
#         mask[i, :] = 1  # Select every `dilation_rate`-th row
#     return mask

In [152]:
import math
from torch.nn import functional as F

# Demo configuration for a single attention window.
seed = 0            # fixes the random q/k/v so re-running the cell reproduces its output
dilation_rate = 2   # step between the positions kept by the dilated mask
window_size = 8     # number of rows of each q/k/v tensor below
hidden_dim = 16     # number of columns of each q/k/v tensor below
offset = True       # NOTE: not passed to the visible call at the end of this cell
head_index = 0

# Seed right before drawing so the three tensors are identical on every run.
torch.manual_seed(seed)
partial_q = torch.randn(window_size, hidden_dim)
partial_k = torch.randn(window_size, hidden_dim)
partial_v = torch.randn(window_size, hidden_dim)

# attention within a window
def dilated_attention_window(partial_q, partial_k, partial_v, dilation_rate, head_index=0, dropout_p=0.0, is_causal=False):
    """Scaled dot-product attention restricted to the dilated positions of one window.

    Steps:
      1. Zero out q, k and v outside the dilated grid given by
         ``create_dilated_mask`` (the same ``(w, d)`` mask is applied to all
         three, so it acts on both the row and the column axis).
      2. Compute ``q @ k^T / sqrt(d)`` plus an optional causal bias.
      3. Compact the scores with ``sparseToDense``, softmax each compact row,
         and re-expand with ``denseToSparse`` so unselected positions get
         weight 0 instead of taking part in the softmax.
      4. Apply dropout, multiply by the masked v and mask the result again.

    Args:
        partial_q, partial_k, partial_v: 2-D tensors for one window. ``w`` is
            read from ``partial_q.size(-2)`` and ``d`` from
            ``partial_k.size(-1)``. ``w`` needs to be a multiple of
            ``dilation_rate``; otherwise the re-expanded weights are
            ``(w // dilation_rate) * dilation_rate`` wide and the final matmul
            with v fails.
        dilation_rate: Step between the positions that are kept.
        head_index: Selects the starting offset ``head_index % dilation_rate``.
        dropout_p: Dropout probability applied to the attention weights
            (dropout is always run in training mode).
        is_causal: When true, position ``i`` may only attend to ``j <= i``.

    Returns:
        A ``(w, d)`` tensor that is zero outside the dilated grid.

    Side effects:
        Prints the re-expanded attention weights (before dropout).
    """
    w, d = partial_q.size(-2), partial_k.size(-1)
    scale_factor = 1 / math.sqrt(d)
    # Create the bias next to the inputs so non-default dtypes/devices work.
    attn_bias = torch.zeros(w, w, dtype=partial_q.dtype, device=partial_q.device)

    # generate and apply masks to q, k, and v
    mask = create_dilated_mask(w, d, dilation_rate, head_index, offset=True).to(partial_q)
    masked_q = partial_q * mask
    masked_k = partial_k * mask
    masked_v = partial_v * mask

    # Apply causal mask if is_causal is True
    if is_causal:
        causal_mask = torch.tril(
            torch.ones(w, w, dtype=torch.bool, device=partial_q.device)
        )
        attn_bias.masked_fill_(~causal_mask, float("-inf"))

    attn_weight = torch.matmul(masked_q, masked_k.transpose(-2, -1)) * scale_factor + attn_bias
    # Softmax only over the selected positions: compact -> softmax -> re-expand.
    attn_weight = sparseToDense(attn_weight, dilation_rate, head_index)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = denseToSparse(attn_weight, dilation_rate, head_index)
    # The helpers may hand back a default-dtype tensor; align it with v, because
    # matmul does not promote dtypes and would raise for e.g. float64 inputs.
    attn_weight = attn_weight.to(masked_v)
    print(attn_weight)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

    output_hat = attn_weight @ masked_v
    output_hat = output_hat * mask # output masking rule
    return output_hat

# Run causal dilated attention on the demo window; the returned tensor is the cell output.
dilated_attention_window(
    partial_q,
    partial_k,
    partial_v,
    dilation_rate=dilation_rate,
    head_index=head_index,
    is_causal=True,
)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5161, 0.0000, 0.4839, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2660, 0.0000, 0.4631, 0.0000, 0.2709, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4051, 0.0000, 0.0932, 0.0000, 0.3941, 0.0000, 0.1076, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])


tensor([[ 0.0317,  0.0000, -0.0871,  0.0000, -1.8382,  0.0000, -1.4081,  0.0000,
         -2.5893,  0.0000, -2.0640,  0.0000, -0.9550,  0.0000, -0.1931,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0381,  0.0000,  0.6472,  0.0000,  0.0804,  0.0000, -0.5493,  0.0000,
         -0.8101,  0.0000, -1.9289,  0.0000, -1.4728,  0.0000, -0.7911,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.1882,  0.0000,  0.1388,  0.0000,  0.4842,  0.0000, -0.1644,  0.0000,
         -0.1446,  0.0000, -1.4414,  0.0000, -1.5465,  0.0000, -1.1938,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.1758,  0.0