In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveAttention(nn.Module):
    """Causal multi-head self-attention with a selective (learned-forgetting) mask.

    The ReLU of head 0's scaled attention logits is used as a "selection"
    score: a strictly earlier query ``k`` that strongly attends to key ``j``
    marks ``j`` as no longer needed.  These scores are accumulated over all
    strictly earlier query positions and subtracted from the logits of every
    head before the softmax, so selected-away tokens receive less attention
    from later queries.  A token never selects itself, the first token is
    never masked, and attention is causal (a query never sees later keys).

    Args:
        embed_dim: Embedding dimension ``E`` of the input and output.
        num_heads: Number of attention heads ``H``; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim, "Embedding dimension must be divisible by number of heads"

        # Linear layers for query, key, and value
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Output projection layer
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Softmax scaling factor 1 / sqrt(head_dim)
        self.scale = self.head_dim ** -0.5

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape a ``(B, S, E)`` projection into per-head ``(B, H, S, D)`` form."""
        batch_size, seq_len, _ = t.shape
        # Split the feature axis into (H, D) first, then move heads before the
        # sequence axis; viewing straight to (B, H, S, D) would mix tokens.
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal selective self-attention.

        Args:
            x: Input of shape ``(batch, seq_len, embed_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)``.
        """
        batch_size, seq_len, _ = x.shape

        # Project input to per-head queries, keys, and values: (B, H, S, D)
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        # Scaled dot-product logits: (B, H, S_query, S_key)
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        pos = torch.arange(seq_len, device=x.device)
        query_pos = pos.unsqueeze(1)  # (S, 1)
        key_pos = pos.unsqueeze(0)  # (1, S)
        future = key_pos > query_pos  # (S, S): keys a query must not attend to
        # Keys a query may not select: itself and the future (key >= query),
        # plus the first token, which is never masked out.
        unselectable = (key_pos >= query_pos) | (key_pos == 0)

        # Head 0 decides what to mask; ReLU keeps only positive selection.
        selection = F.relu(attn_logits[:, 0]).masked_fill(unselectable, 0.0)  # (B, S, S)

        # F[i, j] = sum over strictly earlier queries k < i of selection[k, j]
        # (exclusive cumulative sum along the query axis).
        F_matrix = torch.zeros_like(selection)
        F_matrix[:, 1:, :] = torch.cumsum(selection[:, :-1, :], dim=1)

        # Causal mask, then subtract the accumulated selection from every head.
        masked_attn_logits = attn_logits.masked_fill(future, float("-inf")) - F_matrix.unsqueeze(1)
        masked_attn_weights = F.softmax(masked_attn_logits, dim=-1)

        masked_attn_output = torch.matmul(masked_attn_weights, v)  # (B, H, S, D)

        # Merge heads back: (B, H, S, D) -> (B, S, H, D) -> (B, S, E)
        masked_attn_output = masked_attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)

        # Apply the output projection
        return self.out_proj(masked_attn_output)

# Example usage: push a dummy (batch, seq_len, embed_dim) tensor through the
# layer and report the shape of the result.
batch_size, seq_len, embed_dim, num_heads = 2, 5, 64, 8

x = torch.rand(batch_size, seq_len, embed_dim)  # uniform [0, 1) dummy input
selective_attention = SelectiveAttention(embed_dim, num_heads)
output = selective_attention(x)

print("Output shape:", output.shape)

Output shape: torch.Size([2, 5, 64])
