In [13]:
import torch
import math

In [14]:

def attention(query, key, value, mask):
    """Scaled dot-product attention.

    Computes ``softmax(query @ key^T / sqrt(d_k)) @ value`` where ``d_k`` is
    the size of the last dimension of ``query``.

    Args:
        query: Tensor whose last two dimensions are (seq_len, d_k).
        key: Tensor whose last two dimensions are (seq_len, d_k).
        value: Tensor whose last two dimensions are (seq_len, d_v).
        mask: Optional tensor broadcastable to the score matrix. Entries
            equal to 0 mark positions that must NOT be attended to; every
            other entry leaves the score untouched. Pass ``None`` to
            disable masking.

    Returns:
        The attention output, i.e. the softmax-weighted sum of ``value``
        rows, with the same leading dimensions as ``query``.
    """
    head_dim = query.shape[-1]
    scale = math.sqrt(head_dim)

    # (..., seq_len, d_k) @ (..., d_k, seq_len) --> (..., seq_len, seq_len)
    logits = (query @ key.transpose(-2, -1)) / scale

    if mask is not None:
        # A large negative logit stands in for -inf, so the masked
        # positions end up with (numerically) zero probability.
        blocked = mask == 0
        logits.masked_fill_(blocked, -1e9)

    # Normalise over the key axis, then take the weighted sum of the values.
    weights = logits.softmax(dim=-1)
    return weights @ value

def dilated_attention_singlehead(x, w_seq, r_seq):
    """Single-head causal dilated self-attention (LongNet style) over ``x``.

    ``x`` is used directly as query, key and value (there are no learned
    projections). For every pair ``(w, r)`` taken from ``w_seq`` / ``r_seq``:

    1. the sequence is cut into consecutive segments of length ``w`` (the
       last segment is shorter when ``seq_len`` is not a multiple of ``w``);
    2. inside each segment only every ``r``-th position is kept;
    3. causal scaled dot-product attention is computed among the kept
       positions only;
    4. the result is written back to the original positions. Positions that
       the dilation skipped get a zero output for this pattern.

    The per-pattern outputs are finally mixed, position by position, with
    weights ``alpha_i = s_i / sum_j s_j`` where ``s_i`` is the softmax
    denominator of pattern ``i`` at that position (0 if the pattern skipped
    the position). The weights are computed in log space so that large
    scores cannot overflow, and they stay inside the autograd graph.

    Args:
        x: Floating point tensor of shape (batch, seq_len, dim).
        w_seq: Sequence of segment lengths (positive ints).
        r_seq: Sequence of dilation rates (positive ints), one per entry of
            ``w_seq``.

    Returns:
        Tensor of shape (batch, seq_len, dim). A position that is not kept
        by any pattern has an all-zero output.

    Raises:
        ValueError: If ``w_seq`` and ``r_seq`` are empty, differ in length,
            or contain a non-positive value.
    """
    if len(w_seq) != len(r_seq):
        raise ValueError("w_seq and r_seq must have the same length")
    if len(w_seq) == 0:
        raise ValueError("w_seq and r_seq must not be empty")
    if any(w <= 0 for w in w_seq) or any(r <= 0 for r in r_seq):
        raise ValueError("segment lengths and dilation rates must be positive")

    batch_size, seq_length, dim = x.shape
    # Same 1 / sqrt(d_k) scaling as regular scaled dot-product attention.
    scale = 1.0 / math.sqrt(dim)
    neg_inf = float("-inf")

    all_O = []      # one (batch, seq_len, dim) output per (w, r) pattern
    all_log_s = []  # one (batch, seq_len) log softmax-denominator per pattern

    for w, r in zip(w_seq, r_seq):
        O = torch.zeros_like(x)
        # log(0) = -inf: a skipped position contributes nothing to the mix.
        log_s = x.new_full((batch_size, seq_length), neg_inf)

        for start in range(0, seq_length, w):
            # Positions start, start + r, start + 2r, ... inside this segment.
            kept = slice(start, min(start + w, seq_length), r)
            sparse = x[:, kept, :]  # (batch, n_kept, dim)
            n_kept = sparse.shape[1]

            # (batch, n_kept, dim) @ (batch, dim, n_kept) --> (batch, n_kept, n_kept)
            scores = (sparse @ sparse.transpose(-1, -2)) * scale

            # Causality: a query must not see later keys. The scores have to
            # be set to -inf (not 0) so that softmax assigns them zero weight.
            future = torch.triu(
                torch.ones(n_kept, n_kept, dtype=torch.bool, device=x.device),
                diagonal=1,
            )
            scores = scores.masked_fill(future, neg_inf)

            # Scatter the dense result back to the kept sequence positions.
            O[:, kept, :] = torch.softmax(scores, dim=-1) @ sparse
            # log of the softmax denominator of every kept query row.
            log_s[:, kept] = torch.logsumexp(scores, dim=-1)

        all_O.append(O)
        all_log_s.append(log_s)

    outputs = torch.stack(all_O, dim=0)         # (k, batch, seq_len, dim)
    log_denoms = torch.stack(all_log_s, dim=0)  # (k, batch, seq_len)

    # A position skipped by every pattern would make the softmax below
    # evaluate to NaN. Its outputs are all zero anyway, so any finite
    # weights give the right (zero) result.
    uncovered = torch.isneginf(log_denoms).all(dim=0, keepdim=True)
    log_denoms = log_denoms.masked_fill(uncovered, 0.0)

    # softmax over the pattern axis of log(s_i) equals s_i / sum_j s_j.
    alpha = torch.softmax(log_denoms, dim=0)
    return (alpha.unsqueeze(-1) * outputs).sum(dim=0)

In [16]:
# Demo: run regular causal attention and single-head dilated attention on
# random inputs and check the output shape.

# Fix the RNG so that re-running the notebook reproduces the same tensors.
torch.manual_seed(0)

seq_len = 32

# Shape: (batch_size, num_heads, seq_len, dim)
x = torch.randn(32, 1, seq_len, 512)

# attention() suppresses every position where the mask is 0, so the causal
# mask has to be 1 on and below the diagonal (allowed) and 0 above it
# (future tokens).
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()

# Regular causal attention (note: this is the attention output, not the
# raw score matrix).
scores = attention(x, x, x, causal_mask)

# Input to the dilated attention, shape: (batch_size, seq_len, dim)
x = torch.rand((32, seq_len, 512))

seq_length = x.shape[1]

# Segment lengths and dilation rates both form a geometric series with ratio 2
w_seq = [4, 8, 16]
r_seq = [1, 2, 4]

O = dilated_attention_singlehead(x, w_seq, r_seq)

print(O.shape)



torch.Size([32, 32, 512])
