In [9]:
import torch
import torch.nn.functional as F

class HollowAttention(torch.nn.Module):
    """Multi-head scaled dot-product attention with a built-in look-ahead mask.

    Queries, keys and values are linearly projected, split into ``num_heads``
    heads of ``d_model // num_heads`` features each, and combined with softmax
    attention. Query position ``i`` may only attend to key positions
    ``j <= i``: the strictly upper triangle of the score matrix is masked, so
    the diagonal itself stays visible. No output projection is applied; the
    heads are concatenated and returned directly.

    Args:
        num_heads: Number of attention heads; must be positive and divide
            ``d_model``.
        d_model: Feature size of the inputs and of the output.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, num_heads, d_model):
        super(HollowAttention, self).__init__()
        # Without this check the per-head ``view`` in ``forward`` either fails
        # late or silently produces a wrong sequence length.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                "d_model (%r) must be a positive multiple of num_heads (%r)"
                % (d_model, num_heads)
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads

        self.query_projection = torch.nn.Linear(d_model, d_model)
        self.key_projection = torch.nn.Linear(d_model, d_model)
        self.value_projection = torch.nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """Compute masked multi-head attention.

        Args:
            query: Tensor of shape ``(batch, query_len, d_model)``.
            key: Tensor of shape ``(batch, key_len, d_model)``.
            value: Tensor of shape ``(batch, key_len, d_model)``.
            mask: Optional tensor in which nonzero / ``True`` entries mark
                key positions a query may attend to. A 3-D mask is read as
                ``(batch, query_len, key_len)`` and shared by all heads; any
                other rank must broadcast against
                ``(batch, num_heads, query_len, key_len)``. It is combined
                with the built-in look-ahead mask. ``None`` applies only the
                built-in mask.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``. Query rows left
            with no visible key receive all-zero attention weights instead of
            NaN.
        """
        batch_size = query.shape[0]
        head_shape = (batch_size, -1, self.num_heads, self.head_dim)

        # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
        query = self.query_projection(query).view(*head_shape).transpose(1, 2)
        key = self.key_projection(key).view(*head_shape).transpose(1, 2)
        value = self.value_projection(value).view(*head_shape).transpose(1, 2)

        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Built-in mask: block every key index strictly greater than the query
        # index. Created as a bool tensor directly on the scores' device.
        query_len, key_len = scores.shape[-2:]
        query_idx = torch.arange(query_len, device=scores.device).unsqueeze(-1)
        key_idx = torch.arange(key_len, device=scores.device)
        blocked = key_idx > query_idx

        if mask is not None:
            allowed = mask.to(scores.device) != 0
            if allowed.dim() == 3:
                # Insert the head axis so one mask serves every head.
                allowed = allowed.unsqueeze(1)
            blocked = blocked | ~allowed

        # A row whose keys are all blocked would be softmax(-inf, ...) = NaN.
        # Leave such rows unmasked for the softmax and zero them afterwards,
        # which keeps both the output and the gradients finite.
        empty_rows = blocked.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(blocked & ~empty_rows, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = attention_weights.masked_fill(empty_rows, 0.0)

        output = torch.matmul(attention_weights, value)
        # (batch, heads, seq, head_dim) -> (batch, seq, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return output

In [10]:

# Example usage: push random inputs through HollowAttention and print the output shape
d_model, num_heads = 256, 8
seq_length, batch_size = 10, 4

# Three independent standard-normal tensors, drawn in query, key, value order
query, key, value = (
    torch.randn(batch_size, seq_length, d_model) for _ in range(3)
)

# Fixed (not random) per-sample mask: for samples 0..3 the first 5, 8, 3 and 6
# key columns are ones and the remaining columns zeros, repeated for every query row
mask = (torch.arange(seq_length) < torch.tensor([5, 8, 3, 6]).unsqueeze(-1)).float()
mask = mask.unsqueeze(1).repeat(1, seq_length, 1)

attention = HollowAttention(num_heads, d_model)
output = attention(query, key, value, mask)
print(output.shape)

torch.Size([4, 10, 256])


In [11]:
import torch
import torch.nn.functional as F

class MultiheadHollowAttention(torch.nn.Module):
    """Multi-head attention whose attention matrix has a masked diagonal.

    Queries, keys and values are linearly projected, split into ``num_heads``
    heads of ``d_model // num_heads`` features each, and combined with softmax
    attention in which query position ``i`` may attend to every key position
    except ``j == i``. The concatenated heads are passed through a final
    linear projection.

    Args:
        num_heads: Number of attention heads; must be positive and divide
            ``d_model``.
        d_model: Feature size of the inputs and of the output.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, num_heads, d_model):
        super(MultiheadHollowAttention, self).__init__()
        # Without this check ``split_heads`` either fails late or silently
        # produces a wrong sequence length.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                "d_model (%r) must be a positive multiple of num_heads (%r)"
                % (d_model, num_heads)
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads

        self.query_projection = torch.nn.Linear(d_model, d_model)
        self.key_projection = torch.nn.Linear(d_model, d_model)
        self.value_projection = torch.nn.Linear(d_model, d_model)

        self.final_projection = torch.nn.Linear(d_model, d_model)

    def split_heads(self, tensor, batch_size):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, heads, seq, head_dim)``."""
        return tensor.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Compute hollow (diagonal-masked) multi-head attention.

        Args:
            query: Tensor of shape ``(batch, query_len, d_model)``.
            key: Tensor of shape ``(batch, key_len, d_model)``.
            value: Tensor of shape ``(batch, key_len, d_model)``.
            mask: Optional tensor in which nonzero / ``True`` entries mark
                key positions a query may attend to. A 3-D mask is read as
                ``(batch, query_len, key_len)`` and shared by all heads; any
                other rank must broadcast against
                ``(batch, num_heads, query_len, key_len)``. It is combined
                with the built-in diagonal mask. ``None`` applies only the
                built-in mask.

        Returns:
            Tensor of shape ``(batch, query_len, d_model)``. Query rows left
            with no visible key (for example a length-1 sequence) receive
            all-zero attention weights instead of NaN, so their output is the
            final projection of a zero vector.
        """
        batch_size = query.shape[0]
        query = self.split_heads(self.query_projection(query), batch_size)
        key = self.split_heads(self.key_projection(key), batch_size)
        value = self.split_heads(self.value_projection(value), batch_size)

        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Hollow mask: block only the diagonal so each position sees the
        # positions before and after it but not itself. Masking both strict
        # triangles instead would leave just the diagonal and collapse the
        # attention weights to an identity matrix.
        query_len, key_len = scores.shape[-2:]
        query_idx = torch.arange(query_len, device=scores.device).unsqueeze(-1)
        key_idx = torch.arange(key_len, device=scores.device)
        blocked = key_idx == query_idx

        if mask is not None:
            allowed = mask.to(scores.device) != 0
            if allowed.dim() == 3:
                # Insert the head axis so one mask serves every head.
                allowed = allowed.unsqueeze(1)
            blocked = blocked | ~allowed

        # A row whose keys are all blocked would be softmax(-inf, ...) = NaN.
        # Leave such rows unmasked for the softmax and zero them afterwards,
        # which keeps both the output and the gradients finite.
        empty_rows = blocked.all(dim=-1, keepdim=True)
        scores = scores.masked_fill(blocked & ~empty_rows, float('-inf'))

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = attention_weights.masked_fill(empty_rows, 0.0)

        output = torch.matmul(attention_weights, value)
        # (batch, heads, seq, head_dim) -> (batch, seq, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)

        output = self.final_projection(output)

        return output

# Example usage: push random inputs through MultiheadHollowAttention and print the output shape
d_model, num_heads = 256, 8
seq_length, batch_size = 10, 4

# Three independent standard-normal tensors, drawn in query, key, value order
query, key, value = (
    torch.randn(batch_size, seq_length, d_model) for _ in range(3)
)

# Fixed (not random) per-sample mask: for samples 0..3 the first 5, 8, 3 and 6
# key columns are ones and the remaining columns zeros, repeated for every query row
mask = (torch.arange(seq_length) < torch.tensor([5, 8, 3, 6]).unsqueeze(-1)).float()
mask = mask.unsqueeze(1).repeat(1, seq_length, 1)

attention = MultiheadHollowAttention(num_heads, d_model)
output = attention(query, key, value, mask)
print(output.shape)


torch.Size([4, 10, 256])
