In [18]:
import torch
import math
import time

In [19]:

def attention(query, key, value, mask):
    d_k = query.shape[-1]
    # Just apply the formula from the paper
    # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
    attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Write a very low value (indicating -inf) to the positions where mask == 0
        attention_scores.masked_fill_(mask == 0, -1e9)
    attention_scores = attention_scores.softmax(dim=-1) # (batch, h, seq_len, seq_len) # Apply softmax
    return (attention_scores @ value)

def dilated_attention_singlehead(x, w_seq, r_seq):
    seq_length = x.shape[2]
    k = len(w_seq) # The number of combinations of segment lengths (w) and dilations (r)

    all_O = []
    all_s_i = []

    for param_index in range(k):
        w = w_seq[param_index]
        r = r_seq[param_index]

        all_O_hats = []

        for i in range(seq_length // w):
            segment_content = x[:, i * w : (i + 1) * w, :]

            # Apply the dilation, by zeroing out every rth column
            mask = torch.zeros_like(segment_content)
            mask[:, ::r, :] = 1  
            segment_content = segment_content * mask

            query_tilde = segment_content
            key_tilde = segment_content
            value_tilde = segment_content
            
            O_tilde = query_tilde @ key_tilde.transpose(-1, -2)
            # Zero out out everything above the diagonal to make it causal
            mask = torch.tril(torch.ones_like(O_tilde), diagonal=0)
            O_tilde = O_tilde * mask
            O_tilde = torch.softmax(O_tilde, dim=-1)
            O_tilde = O_tilde @ value_tilde

            # Make all the columns for which column_index % r != 0 equal to 0
            # This is the same as the mask used in the paper
            mask = torch.zeros_like(O_tilde)
            mask[:, :, ::r] = 1
            o_hat = O_tilde * mask

            all_O_hats.append(o_hat)


        O = torch.cat(all_O_hats, dim=1)
        # According to the paper: "s_i" is the denominator of the attention softmax for O
        s_i = torch.sum(torch.exp(O)).item()
        all_O.append(O)
        all_s_i.append(s_i)

    sum_s_i = sum(all_s_i)
    alpha_i = [s_i / sum_s_i for s_i in all_s_i]
    sum_terms = [alpha_i[i] * all_O[i] for i in range(len(all_O))]
    # Sum all the tensors in sum_terms
    return sum(sum_terms)

In [21]:
#Shape: (batch_size, num_heads, seq_len, dim)


# embedding size
d_model = 1024

w_seq = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
r_seq = [1, 2, 4, 8, 16, 32, 64, 128, 256]

# for test_index in range(len(seq_multipliers)):
seq_len = 1024
x = torch.rand(32, 1, seq_len, d_model).to(torch.float)

start_time = time.time()
# calculate the attention according to "Attention is all you need"
causal_mask = torch.tril(torch.ones(seq_len, seq_len), diagonal=0)
attn_normal = attention(x, x, x, causal_mask)
end_time = time.time()
print(f"Time taken for normal attention: {end_time - start_time:.2f} seconds")

start_time = time.time()
# calculate the dilated attention
attn_dilated = dilated_attention_singlehead(x, w_seq, r_seq)
assert attn_dilated.shape == x.shape
end_time = time.time()
print(f"Time taken for dilated attention: {end_time - start_time:.2f} seconds")



Time taken for normal attention: 0.27 seconds
Time taken for dilated attention: 19.45 seconds
