In [1]:
import torch
import torch.nn as nn

In [2]:
class CrossAttention(nn.Module):
    """Cross-attention from a fixed bank of learnable queries onto encoder features.

    The module owns ``query_dim`` learnable query vectors of width ``embed_dim``.
    On every forward pass these queries are broadcast over the batch and attend
    over the encoder features, which serve as both keys and values.

    Note: in both modes the query bank is an ``nn.Parameter`` initialised from a
    standard normal distribution. ``use_positional_embedding`` only selects the
    attribute (and therefore the ``state_dict`` key) under which it is stored:
    ``query_input`` when True, ``query_tokens`` when False.
    """

    def __init__(self, embed_dim, num_heads, query_dim, use_positional_embedding=True):
        """
        embed_dim: Feature dimensionality (C)
        num_heads: Number of attention heads
        query_dim: Number of queries (e.g., 128 * 128 for high-res queries)
        use_positional_embedding: If True, store the learnable queries as
            ``query_input``; else, store them as ``query_tokens``.
        """
        super().__init__()
        self.use_positional_embedding = use_positional_embedding
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

        # Created after the attention layer so the RNG consumption order (and thus
        # seeded initialisation) matches the previous implementation.
        query_bank = nn.Parameter(torch.randn(1, query_dim, embed_dim))  # (1, query_dim, C)
        if use_positional_embedding:
            self.query_input = query_bank
        else:
            self.query_tokens = query_bank

    def forward(self, encoder_out):
        """
        encoder_out: (B, S, C) - Encoder features, with C == embed_dim.
        Returns: (B, query_dim, C) - Query features after cross-attention.
        Raises: ValueError if encoder_out is not 3-D or its last dimension is not embed_dim.
        """
        embed_dim = self.attn.embed_dim
        if encoder_out.dim() != 3 or encoder_out.shape[-1] != embed_dim:
            raise ValueError(
                f"encoder_out must have shape (B, S, {embed_dim}), got {tuple(encoder_out.shape)}"
            )

        B = encoder_out.shape[0]

        query_bank = self.query_input if self.use_positional_embedding else self.query_tokens
        # expand() broadcasts over the batch dimension without copying the parameter.
        queries = query_bank.expand(B, -1, -1)  # (B, query_dim, C)

        # The attention weights are never used, so skip computing them:
        # need_weights=False avoids building the (B, query_dim, S) weight tensor
        # and lets PyTorch use its fused scaled-dot-product attention path.
        output, _ = self.attn(queries, encoder_out, encoder_out, need_weights=False)
        return output

In [3]:
# Demo: push one random encoder feature grid through both query modes.
B = 2                 # batch size
H_kv, W_kv = 16, 128  # key/value (encoder) grid
H_q, W_q = 128, 128   # query grid
C = 256               # feature width
num_heads = 8
query_dim = H_q * W_q  # one query per cell of the query grid

# Random stand-in for the flattened encoder output: (B, H_kv * W_kv, C)
F = torch.randn(B, H_kv * W_kv, C)

# Mode 1: queries stored as the "positional embedding" parameter
cross_attn_with_pos = CrossAttention(
    embed_dim=C,
    num_heads=num_heads,
    query_dim=query_dim,
    use_positional_embedding=True,
)
I_updated_with_pos = cross_attn_with_pos(F)
print("With Positional Embedding:", I_updated_with_pos.shape)  # (B, H_q * W_q, C)

# Mode 2: queries stored as learnable query tokens
cross_attn_with_tokens = CrossAttention(
    embed_dim=C,
    num_heads=num_heads,
    query_dim=query_dim,
    use_positional_embedding=False,
)
I_updated_with_tokens = cross_attn_with_tokens(F)
print("With Learnable Query Tokens:", I_updated_with_tokens.shape)  # (B, H_q * W_q, C)

With Positional Embedding: torch.Size([2, 16384, 256])
With Learnable Query Tokens: torch.Size([2, 16384, 256])
