In [1]:
pip install timm torch torchvision torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, List
import math

class SwinTransformerPatchEmbedding(nn.Module):
    """Simplified Swin Transformer patch embedding for feature extraction"""
    def __init__(self, patch_size: int = 4, embed_dim: int = 96):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, C, H, W]
        x = self.proj(x)  # [B, embed_dim, H//patch_size, W//patch_size]
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, H*W, C]
        return x, (H, W)

class GridGraphFeature(nn.Module):
    """
    Grid-graph feature extraction based on the paper description
    """
    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 4,
                 embed_dim: int = 96,
                 gcn_hidden_dim: int = 128):
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.gcn_hidden_dim = gcn_hidden_dim

        # Calculate number of patches
        self.num_patches_per_side = image_size // patch_size  # M = k = 56 for 224/4
        self.num_patches = self.num_patches_per_side ** 2     # M*k patches total

        # Swin Transformer for patch feature extraction
        self.swin_embedding = SwinTransformerPatchEmbedding(patch_size, embed_dim)

        # Graph Convolutional Layer
        self.gcn = GraphConvolutionalLayer(embed_dim, gcn_hidden_dim)

    def extract_patch_features(self, image: torch.Tensor) -> torch.Tensor:
        """
        Extract patch features using Swin Transformer
        Args:
            image: [B, 3, 224, 224] input image
        Returns:
            X^V: [B, M*k, embed_dim] patch feature matrix
        """
        # Split image into M x k patches and extract features
        patch_features, (H, W) = self.swin_embedding(image)  # [B, M*k, embed_dim]

        return patch_features

    def compute_similarity_matrix(self, patch_features: torch.Tensor) -> torch.Tensor:
        """
        Compute similarity matrix between patches
        Args:
            patch_features: [B, M*k, embed_dim] patch features
        Returns:
            A^V: [B, M*k, M*k] adjacency matrix
        """
        B, N, D = patch_features.shape

        # Normalize patch features for cosine similarity
        patch_features_norm = F.normalize(patch_features, p=2, dim=-1)

        # Compute cosine similarity matrix
        similarity_matrix = torch.bmm(patch_features_norm, patch_features_norm.transpose(-2, -1))

        # Create adjacency matrix based on similarity
        # A^V_ij = 1 if |i-j| = 1 or |i-j| = k (adjacent patches)
        #        = sim(i,j) otherwise
        adjacency_matrix = torch.zeros_like(similarity_matrix)

        for i in range(N):
            for j in range(N):
                # Convert 1D indices to 2D grid coordinates
                i_row, i_col = i // self.num_patches_per_side, i % self.num_patches_per_side
                j_row, j_col = j // self.num_patches_per_side, j % self.num_patches_per_side

                # Check if patches are adjacent (horizontally or vertically)
                if abs(i_row - j_row) + abs(i_col - j_col) == 1:
                    adjacency_matrix[:, i, j] = 1.0
                elif i == j:
                    adjacency_matrix[:, i, j] = 1.0
                else:
                    adjacency_matrix[:, i, j] = similarity_matrix[:, i, j]
        return adjacency_matrix

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of grid-graph feature extraction
        Args:
            image: [B, 3, 224, 224] input image
        Returns:
            V^(0): [B, M*k, hidden_dim] initial context grid representation
        """
        # Step 1: Extract patch features using Swin Transformer
        patch_features = self.extract_patch_features(image)  # X^V

        # Step 2: Compute similarity-based adjacency matrix
        adjacency_matrix = self.compute_similarity_matrix(patch_features)  # A^V

        # Step 3: Apply Graph Convolutional Layer
        grid_representation = self.gcn(patch_features, adjacency_matrix)  # V^(0)

        return grid_representation

class GraphConvolutionalLayer(nn.Module):
    """
    Graph Convolutional Layer implementation
    V^(0) = σ(Ã^T VW^T)
    """
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Weight matrix W
        self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, node_features: torch.Tensor, adjacency_matrix: torch.Tensor) -> torch.Tensor:
        """
        Args:
            node_features: [B, N, input_dim] node feature matrix V
            adjacency_matrix: [B, N, N] adjacency matrix A^V
        Returns:
            output: [B, N, output_dim] output features V^(0)
        """
        # Compute degree matrix D
        degree_matrix = torch.sum(adjacency_matrix, dim=-1, keepdim=True)  # [B, N, 1]
        degree_matrix = torch.clamp(degree_matrix, min=1.0)  # Avoid division by zero

        # Normalize adjacency matrix: Ã = D^(-1/2) A^V D^(-1/2)
        degree_inv_sqrt = torch.pow(degree_matrix, -0.5)
        normalized_adj = degree_inv_sqrt * adjacency_matrix * degree_inv_sqrt.transpose(-2, -1)

        # Apply weight transformation: VW^T
        transformed_features = torch.matmul(node_features, self.weight)  # [B, N, output_dim]

        # Graph convolution: Ã^T VW^T
        output = torch.bmm(normalized_adj.transpose(-2, -1), transformed_features)

        # Apply activation function (ReLU)
        output = F.relu(output)

        return output

# Example usage and testing
def test_grid_graph_feature():
    """Test the Grid-Graph Feature implementation"""

    # Create model
    model = GridGraphFeature(
        image_size=224,
        patch_size=4,
        embed_dim=96,
        gcn_hidden_dim=128
    )

    # Create dummy input
    batch_size = 2
    dummy_image = torch.randn(batch_size, 3, 224, 224)

    print(f"Input image shape: {dummy_image.shape}")
    print(f"Number of patches: {model.num_patches}")
    print(f"Patches per side: {model.num_patches_per_side}")

    # Forward pass
    with torch.no_grad():
        grid_representation = model(dummy_image)

    print(f"Output grid representation shape: {grid_representation.shape}")
    print(f"Expected shape: [{batch_size}, {model.num_patches}, {model.gcn_hidden_dim}]")

    # Test individual components
    print("\n--- Testing individual components ---")

    # Test patch feature extraction
    patch_features = model.extract_patch_features(dummy_image)
    print(f"Patch features shape: {patch_features.shape}")

    # Test similarity matrix computation
    similarity_matrix = model.compute_similarity_matrix(patch_features)
    print(f"Similarity matrix shape: {similarity_matrix.shape}")

    # Verify adjacency matrix properties
    print(f"Adjacency matrix diagonal sum: {torch.diagonal(similarity_matrix, dim1=-2, dim2=-1).sum()}")
    print(f"Adjacency matrix range: [{similarity_matrix.min():.3f}, {similarity_matrix.max():.3f}]")

if __name__ == "__main__":
    test_grid_graph_feature()

Input image shape: torch.Size([2, 3, 224, 224])
Number of patches: 3136
Patches per side: 56
Output grid representation shape: torch.Size([2, 3136, 128])
Expected shape: [2, 3136, 128]

--- Testing individual components ---
Patch features shape: torch.Size([2, 3136, 96])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Tuple, List, Dict
import torchvision.ops as ops

class FasterRCNN(nn.Module):
    """Simplified Faster R-CNN for object region detection"""
    def __init__(self, backbone_dim: int = 512, num_classes: int = 80):
        super().__init__()
        self.backbone_dim = backbone_dim
        self.num_classes = num_classes

        # Simplified backbone (normally ResNet)
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, backbone_dim, 3, padding=1),
            nn.ReLU(inplace=True)
        )

        # ROI pooling output dimension
        self.roi_pool_size = 7
        self.roi_pooling = ops.RoIPool(output_size=(self.roi_pool_size, self.roi_pool_size), spatial_scale=1/16)

        # Feature extraction after ROI pooling
        self.roi_head = nn.Sequential(
            nn.Linear(backbone_dim * self.roi_pool_size * self.roi_pool_size, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True)
        )

    def forward(self, images: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
        """
        Extract features for given region boxes
        Args:
            images: [B, 3, H, W] input images
            boxes: [B, P, 4] region boxes in format [x1, y1, x2, y2]
        Returns:
            region_features: [B, P, 512] features for each region
        """
        # Extract backbone features
        backbone_features = self.backbone(images)  # [B, backbone_dim, H', W']

        B, P, _ = boxes.shape
        region_features = []

        for b in range(B):
            # Prepare boxes for ROI pooling (add batch index)
            batch_boxes = torch.cat([
                torch.full((P, 1), b, dtype=boxes.dtype, device=boxes.device),
                boxes[b]
            ], dim=1)  # [P, 5] format: [batch_idx, x1, y1, x2, y2]

            # ROI pooling
            pooled_features = self.roi_pooling(backbone_features[b:b+1], batch_boxes)  # [P, backbone_dim, 7, 7]

            # Flatten and process through ROI head
            pooled_flat = pooled_features.view(P, -1)  # [P, backbone_dim * 7 * 7]
            roi_features = self.roi_head(pooled_flat)  # [P, 512]

            region_features.append(roi_features)

        return torch.stack(region_features, dim=0)  # [B, P, 512]

class MultiHeadAttention(nn.Module):
    """Multi-Head Attention mechanism for region features"""
    def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Output projection
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query, key, value: [B, P, d_model] region features
        Returns:
            output: [B, P, d_model] attended features
        """
        B, P, d_model = query.shape

        # Linear projections
        Q = self.w_q(query)  # [B, P, d_model]
        K = self.w_k(key)    # [B, P, d_model]
        V = self.w_v(value)  # [B, P, d_model]

        # Reshape for multi-head attention
        Q = Q.view(B, P, self.num_heads, self.d_k).transpose(1, 2)  # [B, num_heads, P, d_k]
        K = K.view(B, P, self.num_heads, self.d_k).transpose(1, 2)  # [B, num_heads, P, d_k]
        V = V.view(B, P, self.num_heads, self.d_k).transpose(1, 2)  # [B, num_heads, P, d_k]

        # Scaled dot-product attention
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [B, num_heads, P, P]
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        attended_values = torch.matmul(attention_weights, V)  # [B, num_heads, P, d_k]

        # Concatenate heads
        attended_values = attended_values.transpose(1, 2).contiguous().view(B, P, d_model)

        # Final linear projection
        output = self.w_o(attended_values)

        return output

class SelfAttention(nn.Module):
    """Self-Attention mechanism as described in the paper"""
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.scale = math.sqrt(d_model)

    def forward(self, R: torch.Tensor) -> torch.Tensor:
        """
        Compute QK^T Self-Attention
        Args:
            R: [B, P, d] region features where Q = K = V = R
        Returns:
            attention_output: [B, P, d] attended features
        """
        # Q = K = V = R
        Q = K = V = R  # [B, P, d]

        # Compute attention scores: QK^T / sqrt(d)
        attention_scores = torch.bmm(Q, K.transpose(-2, -1)) / self.scale  # [B, P, P]

        # Apply softmax
        attention_weights = F.softmax(attention_scores, dim=-1)  # [B, P, P]

        # Apply attention to values
        attention_output = torch.bmm(attention_weights, V)  # [B, P, d]

        return attention_output

class RegionGraphFeature(nn.Module):
    """
    Region Graph Feature extraction using Faster R-CNN and Multi-Head Attention
    """
    def __init__(self,
                 num_regions: int = 100,
                 region_feature_dim: int = 512,
                 num_attention_heads: int = 8,
                 output_dim: int = 512):
        super().__init__()

        self.num_regions = num_regions
        self.region_feature_dim = region_feature_dim
        self.num_attention_heads = num_attention_heads
        self.output_dim = output_dim

        # Faster R-CNN for region detection and feature extraction
        self.faster_rcnn = FasterRCNN(backbone_dim=512, num_classes=80)

        # Self-Attention mechanism
        self.self_attention = SelfAttention(region_feature_dim)

        # Multi-Head Attention mechanism
        self.multi_head_attention = MultiHeadAttention(
            d_model=region_feature_dim,
            num_heads=num_attention_heads
        )

        # Final projection layer
        self.output_projection = nn.Linear(region_feature_dim, output_dim)

    def generate_dummy_boxes(self, batch_size: int, num_boxes: int, image_size: Tuple[int, int]) -> torch.Tensor:
        """
        Generate dummy bounding boxes for testing
        In practice, these would come from Faster R-CNN's RPN
        """
        H, W = image_size
        boxes = []

        for _ in range(batch_size):
            batch_boxes = []
            for _ in range(num_boxes):
                # Random box coordinates
                x1 = torch.randint(0, W//2, (1,)).float()
                y1 = torch.randint(0, H//2, (1,)).float()
                x2 = x1 + torch.randint(W//4, W//2, (1,)).float()
                y2 = y1 + torch.randint(H//4, H//2, (1,)).float()

                # Ensure boxes are within image bounds
                x2 = torch.clamp(x2, max=W-1)
                y2 = torch.clamp(y2, max=H-1)

                batch_boxes.append(torch.stack([x1, y1, x2, y2]))

            boxes.append(torch.stack(batch_boxes).squeeze())

        return torch.stack(boxes)

    def forward(self, images: torch.Tensor, region_boxes: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass for region graph feature extraction
        Args:
            images: [B, 3, H, W] input images
            region_boxes: [B, P, 4] region boxes (optional, will generate dummy if None)
        Returns:
            R^(0): [B, P, output_dim] initial region representation
        """
        B, C, H, W = images.shape

        # Generate dummy boxes if not provided (in practice, use Faster R-CNN's RPN)
        if region_boxes is None:
            region_boxes = self.generate_dummy_boxes(B, self.num_regions, (H, W))
            region_boxes = region_boxes.to(images.device)

        # Step 1: Extract region features using Faster R-CNN
        region_features = self.faster_rcnn(images, region_boxes)  # R: [B, P, d]

        # Step 2: Apply Self-Attention mechanism
        # QK^T Attention where Q = K = V = R
        sa_output = self.self_attention(region_features)  # [B, P, d]

        # Step 3: Apply Multi-Head Attention for more comprehensive interaction
        mha_output = self.multi_head_attention(sa_output, sa_output, sa_output)  # [B, P, d]

        # Step 4: Combine with residual connection
        combined_features = region_features + mha_output  # [B, P, d]

        # Step 5: Final projection to get R^(0)
        initial_region_representation = self.output_projection(combined_features)  # [B, P, output_dim]

        return initial_region_representation

    def get_region_vectors(self, region_representation: torch.Tensor) -> List[torch.Tensor]:
        """
        Extract individual region vectors {r_i}
        Args:
            region_representation: [B, P, d] region features
        Returns:
            List of region vectors for each batch
        """
        B, P, d = region_representation.shape
        region_vectors = []

        for b in range(B):
            batch_vectors = [region_representation[b, p, :] for p in range(P)]
            region_vectors.append(batch_vectors)

        return region_vectors

# Example usage and testing
def test_region_graph_feature():
    """Test the Region Graph Feature implementation"""

    # Create model
    model = RegionGraphFeature(
        num_regions=50,
        region_feature_dim=512,
        num_attention_heads=8,
        output_dim=512
    )

    # Create dummy input
    batch_size = 2
    dummy_images = torch.randn(batch_size, 3, 224, 224)

    print(f"Input images shape: {dummy_images.shape}")
    print(f"Number of regions: {model.num_regions}")
    print(f"Region feature dimension: {model.region_feature_dim}")

    # Forward pass
    with torch.no_grad():
        region_representation = model(dummy_images)

    print(f"Output region representation shape: {region_representation.shape}")
    print(f"Expected shape: [{batch_size}, {model.num_regions}, {model.output_dim}]")

    # Test with custom boxes
    print("\n--- Testing with custom region boxes ---")
    custom_boxes = torch.tensor([
        [[10, 10, 50, 50], [60, 60, 100, 100], [120, 30, 180, 90]],
        [[20, 20, 80, 80], [90, 10, 150, 70], [30, 100, 90, 160]]
    ], dtype=torch.float32)

    with torch.no_grad():
        custom_region_representation = model(dummy_images, custom_boxes)

    print(f"Custom region representation shape: {custom_region_representation.shape}")

    # Test individual components
    print("\n--- Testing individual components ---")

    # Test Faster R-CNN feature extraction
    dummy_boxes = model.generate_dummy_boxes(batch_size, 10, (224, 224))
    region_features = model.faster_rcnn(dummy_images, dummy_boxes)
    print(f"Faster R-CNN features shape: {region_features.shape}")

    # Test Self-Attention
    sa_output = model.self_attention(region_features)
    print(f"Self-Attention output shape: {sa_output.shape}")

    # Test Multi-Head Attention
    mha_output = model.multi_head_attention(region_features, region_features, region_features)
    print(f"Multi-Head Attention output shape: {mha_output.shape}")

    # Extract individual region vectors
    region_vectors = model.get_region_vectors(region_representation)
    print(f"Number of region vector batches: {len(region_vectors)}")
    print(f"Number of vectors per batch: {len(region_vectors[0])}")
    print(f"Each region vector shape: {region_vectors[0][0].shape}")

if __name__ == "__main__":
    test_region_graph_feature()

Input images shape: torch.Size([2, 3, 224, 224])
Number of regions: 50
Region feature dimension: 512
Output region representation shape: torch.Size([2, 50, 512])
Expected shape: [2, 50, 512]

--- Testing with custom region boxes ---
Custom region representation shape: torch.Size([2, 3, 512])

--- Testing individual components ---


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
import numpy as np
import math
from typing import List, Dict, Tuple, Optional
import spacy
import networkx as nx

class BERTEmbedding(nn.Module):
    """BERT-based text embedding for semantic features"""
    def __init__(self, model_name: str = 'bert-base-uncased', embedding_dim: int = 768):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert_model = BertModel.from_pretrained(model_name)

        # Freeze BERT parameters for efficiency (optional)
        for param in self.bert_model.parameters():
            param.requires_grad = False

    def forward(self, text_sequences: List[str]) -> torch.Tensor:
        """
        Convert text sequences to BERT embeddings
        Args:
            text_sequences: List of text strings
        Returns:
            embeddings: [N, embedding_dim] where N is number of sequences
        """
        # Tokenize texts
        encoded = self.tokenizer(
            text_sequences,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )

        # Get BERT embeddings
        with torch.no_grad():
            outputs = self.bert_model(**encoded)
            # Use [CLS] token embedding as sentence representation
            embeddings = outputs.last_hidden_state[:, 0, :]  # [N, embedding_dim]

        return embeddings

class DependencyParser:
    """Dependency parsing for constructing semantic graphs"""
    def __init__(self):
        # Load spaCy model for dependency parsing
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            print("Warning: spaCy English model not found. Using dummy parser.")
            self.nlp = None

    def parse_text(self, text: str) -> List[Dict]:
        """
        Parse text and extract dependency relationships
        Args:
            text: Input text string
        Returns:
            dependencies: List of dependency relations
        """
        if self.nlp is None:
            # Dummy dependencies for testing
            words = text.split()
            return [{"head": i, "child": i+1, "relation": "dummy"}
                   for i in range(len(words)-1)]

        doc = self.nlp(text)
        dependencies = []

        for token in doc:
            if token.head != token:  # Skip root
                dependencies.append({
                    "head": token.head.i,
                    "child": token.i,
                    "relation": token.dep_,
                    "head_text": token.head.text,
                    "child_text": token.text
                })

        return dependencies

    def extract_words(self, text: str) -> List[str]:
        """Extract individual words from text"""
        if self.nlp is None:
            return text.split()

        doc = self.nlp(text)
        return [token.text for token in doc if not token.is_punct]

class SemanticGraphConstructor:
    """Construct semantic graphs from dependency trees"""
    def __init__(self):
        self.dependency_parser = DependencyParser()

    def construct_adjacency_matrix(self, text: str) -> Tuple[torch.Tensor, List[str]]:
        """
        Construct adjacency matrix from dependency tree
        Args:
            text: Input text string
        Returns:
            adjacency_matrix: [N, N] adjacency matrix
            words: List of words corresponding to matrix indices
        """
        # Extract words and dependencies
        words = self.dependency_parser.extract_words(text)
        dependencies = self.dependency_parser.parse_text(text)

        N = len(words)
        if N == 0:
            return torch.zeros(1, 1), [""]

        # Initialize adjacency matrix
        adjacency_matrix = torch.zeros(N, N)

        # Fill adjacency matrix based on dependencies
        for dep in dependencies:
            head_idx = dep["head"]
            child_idx = dep["child"]

            # Ensure indices are within bounds
            if 0 <= head_idx < N and 0 <= child_idx < N:
                # Undirected graph: set both directions
                adjacency_matrix[head_idx, child_idx] = 1.0
                adjacency_matrix[child_idx, head_idx] = 1.0

        # Add self-connections (diagonal)
        for i in range(N):
            adjacency_matrix[i, i] = 1.0

        return adjacency_matrix, words

    def correlation_function(self, wi: str, wj: str) -> float:
        """
        Compute correlation between two words D(wi, wj)
        This is a simplified implementation - in practice, could use:
        - Word embeddings similarity
        - Co-occurrence statistics
        - Semantic similarity measures
        """
        # Simple correlation based on string similarity and length
        if wi == wj:
            return 1.0

        # Jaccard similarity of character sets
        set_i = set(wi.lower())
        set_j = set(wj.lower())

        if len(set_i.union(set_j)) == 0:
            return 0.0

        correlation = len(set_i.intersection(set_j)) / len(set_i.union(set_j))
        return correlation

class SemanticGraphConvolutionalLayer(nn.Module):
    """
    Semantic Graph Convolutional Layer
    S^(0) = σ(Ã^S σ(Ã^S SW₁^S) W₂^S)
    """
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Weight matrices
        self.W1 = nn.Parameter(torch.FloatTensor(input_dim, hidden_dim))
        self.W2 = nn.Parameter(torch.FloatTensor(hidden_dim, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        stdv1 = 1. / math.sqrt(self.W1.size(1))
        self.W1.data.uniform_(-stdv1, stdv1)

        stdv2 = 1. / math.sqrt(self.W2.size(1))
        self.W2.data.uniform_(-stdv2, stdv2)

    def normalize_adjacency_matrix(self, adjacency_matrix: torch.Tensor) -> torch.Tensor:
        """
        Normalize adjacency matrix: Ã = D^(-1/2) A D^(-1/2)
        """
        # Compute degree matrix
        degree = torch.sum(adjacency_matrix, dim=-1, keepdim=True)  # [N, 1]
        degree = torch.clamp(degree, min=1.0)  # Avoid division by zero

        # D^(-1/2)
        degree_inv_sqrt = torch.pow(degree, -0.5)

        # Normalize: D^(-1/2) A D^(-1/2)
        normalized_adj = degree_inv_sqrt * adjacency_matrix * degree_inv_sqrt.transpose(-2, -1)

        return normalized_adj

    def forward(self, node_features: torch.Tensor, adjacency_matrix: torch.Tensor) -> torch.Tensor:
        """
        Args:
            node_features: [N, input_dim] node features S
            adjacency_matrix: [N, N] adjacency matrix A^S
        Returns:
            output: [N, output_dim] output features S^(0)
        """
        # Normalize adjacency matrix
        normalized_adj = self.normalize_adjacency_matrix(adjacency_matrix)  # Ã^S

        # First GCN layer: Ã^S S W₁^S
        h1 = torch.matmul(node_features, self.W1)  # [N, hidden_dim]
        h1 = torch.matmul(normalized_adj, h1)      # [N, hidden_dim]
        h1 = F.relu(h1)  # σ(Ã^S S W₁^S)

        # Second GCN layer: Ã^S σ(Ã^S S W₁^S) W₂^S
        h2 = torch.matmul(h1, self.W2)             # [N, output_dim]
        output = torch.matmul(normalized_adj, h2)   # [N, output_dim]
        output = F.relu(output)  # σ(Ã^S σ(Ã^S S W₁^S) W₂^S)

        return output

class SemanticGraphFeature(nn.Module):
    """
    Complete Semantic Graph Feature extraction pipeline
    """
    def __init__(self,
                 bert_model_name: str = 'bert-base-uncased',
                 bert_dim: int = 768,
                 gcn_hidden_dim: int = 512,
                 output_dim: int = 256):
        super().__init__()

        self.bert_dim = bert_dim
        self.gcn_hidden_dim = gcn_hidden_dim
        self.output_dim = output_dim

        # BERT embedding for word-level features
        self.bert_embedding = BERTEmbedding(bert_model_name, bert_dim)

        # Semantic graph constructor
        self.graph_constructor = SemanticGraphConstructor()

        # Semantic Graph Convolutional Layer
        self.semantic_gcn = SemanticGraphConvolutionalLayer(
            input_dim=bert_dim,
            hidden_dim=gcn_hidden_dim,
            output_dim=output_dim
        )

    def process_single_text(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Process a single text to extract semantic graph features
        Args:
            text: Input text string
        Returns:
            semantic_features: [N, output_dim] semantic node features
            adjacency_matrix: [N, N] semantic adjacency matrix
        """
        # Step 1: Construct semantic graph
        adjacency_matrix, words = self.graph_constructor.construct_adjacency_matrix(text)

        if len(words) == 0 or words == [""]:
            # Handle empty text
            return torch.zeros(1, self.output_dim), torch.zeros(1, 1)

        # Step 2: Get BERT embeddings for words
        word_embeddings = self.bert_embedding(words)  # [N, bert_dim]

        # Step 3: Apply Semantic GCN
        semantic_features = self.semantic_gcn(word_embeddings, adjacency_matrix)  # [N, output_dim]

        return semantic_features, adjacency_matrix

    def forward(self, text_list: List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Process multiple texts
        Args:
            text_list: List of text strings
        Returns:
            semantic_features_list: List of [Ni, output_dim] tensors
            adjacency_matrices_list: List of [Ni, Ni] tensors
        """
        semantic_features_list = []
        adjacency_matrices_list = []

        for text in text_list:
            semantic_features, adjacency_matrix = self.process_single_text(text)
            semantic_features_list.append(semantic_features)
            adjacency_matrices_list.append(adjacency_matrix)

        return semantic_features_list, adjacency_matrices_list

    def get_semantic_vectors(self, semantic_features_list: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """
        Extract individual semantic vectors {s_i}
        Args:
            semantic_features_list: List of [Ni, output_dim] tensors
        Returns:
            List of lists containing individual semantic vectors
        """
        semantic_vectors = []

        for semantic_features in semantic_features_list:
            N, d = semantic_features.shape
            text_vectors = [semantic_features[i, :] for i in range(N)]
            semantic_vectors.append(text_vectors)

        return semantic_vectors

# Example usage and testing
def test_semantic_graph_feature():
    """Test the Semantic Graph Feature implementation"""

    print("Testing Semantic Graph Feature Implementation")
    print("=" * 50)

    # Create model
    model = SemanticGraphFeature(
        bert_model_name='bert-base-uncased',
        bert_dim=768,
        gcn_hidden_dim=512,
        output_dim=256
    )

    # Test texts (image captions)
    test_texts = [
        "A cat sitting on a wooden table",
        "Two dogs playing in the park",
        "Beautiful sunset over the ocean",
        "Person riding bicycle on street"
    ]

    print(f"Test texts: {len(test_texts)} captions")
    for i, text in enumerate(test_texts):
        print(f"  {i+1}. {text}")

    # Process texts
    semantic_features_list, adjacency_matrices_list = model(test_texts)

    print(f"\nResults:")
    print(f"Number of processed texts: {len(semantic_features_list)}")

    for i, (features, adj_matrix) in enumerate(zip(semantic_features_list, adjacency_matrices_list)):
        print(f"\nText {i+1}: '{test_texts[i]}'")
        print(f"  Semantic features shape: {features.shape}")
        print(f"  Adjacency matrix shape: {adj_matrix.shape}")
        print(f"  Number of words/nodes: {features.shape[0]}")
        print(f"  Feature dimension: {features.shape[1]}")
        print(f"  Graph density: {(adj_matrix.sum() - adj_matrix.trace()) / (adj_matrix.numel() - adj_matrix.shape[0]):.3f}")

    # Test individual components
    print(f"\n" + "="*50)
    print("Testing Individual Components")
    print("="*50)

    # Test dependency parsing
    test_text = "A cat sitting on a wooden table"
    dependencies = model.graph_constructor.dependency_parser.parse_text(test_text)
    words = model.graph_constructor.dependency_parser.extract_words(test_text)

    print(f"\nDependency parsing for: '{test_text}'")
    print(f"Words: {words}")
    print(f"Dependencies: {len(dependencies)}")
    for dep in dependencies[:3]:  # Show first 3
        print(f"  {dep}")

    # Test adjacency matrix construction
    adj_matrix, extracted_words = model.graph_constructor.construct_adjacency_matrix(test_text)
    print(f"\nAdjacency matrix shape: {adj_matrix.shape}")
    print(f"Extracted words: {extracted_words}")
    print(f"Adjacency matrix:\n{adj_matrix}")

    # Test BERT embeddings
    word_embeddings = model.bert_embedding(words[:3])  # Test first 3 words
    print(f"\nBERT embeddings shape for 3 words: {word_embeddings.shape}")
    print(f"Embedding dimension: {word_embeddings.shape[1]}")

    # Test semantic vectors extraction
    semantic_vectors = model.get_semantic_vectors(semantic_features_list)
    print(f"\nSemantic vectors:")
    print(f"Number of texts: {len(semantic_vectors)}")
    if len(semantic_vectors) > 0:
        print(f"Vectors in first text: {len(semantic_vectors[0])}")
        if len(semantic_vectors[0]) > 0:
            print(f"Each vector shape: {semantic_vectors[0][0].shape}")

if __name__ == "__main__":
    test_semantic_graph_feature()

2025-05-30 03:04:31.419242: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748574271.616816     113 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748574271.671914     113 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Testing Semantic Graph Feature Implementation


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Test texts: 4 captions
  1. A cat sitting on a wooden table
  2. Two dogs playing in the park
  3. Beautiful sunset over the ocean
  4. Person riding bicycle on street

Results:
Number of processed texts: 4

Text 1: 'A cat sitting on a wooden table'
  Semantic features shape: torch.Size([7, 256])
  Adjacency matrix shape: torch.Size([7, 7])
  Number of words/nodes: 7
  Feature dimension: 256
  Graph density: 0.286

Text 2: 'Two dogs playing in the park'
  Semantic features shape: torch.Size([6, 256])
  Adjacency matrix shape: torch.Size([6, 6])
  Number of words/nodes: 6
  Feature dimension: 256
  Graph density: 0.333

Text 3: 'Beautiful sunset over the ocean'
  Semantic features shape: torch.Size([5, 256])
  Adjacency matrix shape: torch.Size([5, 5])
  Number of words/nodes: 5
  Feature dimension: 256
  Graph density: 0.400

Text 4: 'Person riding bicycle on street'
  Semantic features shape: torch.Size([5, 256])
  Adjacency matrix shape: torch.Size([5, 5])
  Number of words/nodes: 5


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Tuple, List, Dict, Optional

class RegionAggregationFeature(nn.Module):
    """
    Region Aggregation Feature implementation
    Aggregates grid features into representative regional centers
    """
    def __init__(self, 
                 grid_feature_dim: int = 128,
                 num_regions: int = 10,
                 num_clusters: int = 5,
                 temperature: float = 1.0):
        super().__init__()
        
        self.grid_feature_dim = grid_feature_dim
        self.num_regions = num_regions
        self.num_clusters = num_clusters
        self.temperature = temperature
        
        # Learnable cluster centers r_i^(0)
        self.cluster_centers = nn.Parameter(
            torch.randn(num_clusters, grid_feature_dim)
        )
        
        # Adjustable parameters b_r and b_i
        self.b_r = nn.Parameter(torch.zeros(1))
        self.b_i = nn.Parameter(torch.zeros(1))
        
        # Learnable parameter r̃_i for weighted summing
        self.r_tilde = nn.Parameter(torch.randn(grid_feature_dim))
        
        self.reset_parameters()
    
    def reset_parameters(self):
        """Initialize parameters"""
        # Initialize cluster centers with Xavier uniform
        nn.init.xavier_uniform_(self.cluster_centers)
        
        # Initialize biases to zero
        nn.init.zeros_(self.b_r)
        nn.init.zeros_(self.b_i)
        
        # Initialize r̃_i with small random values
        nn.init.normal_(self.r_tilde, mean=0.0, std=0.1)
    
    def compute_correlation_weights(self, 
                                   grid_features: torch.Tensor, 
                                   cluster_centers: torch.Tensor) -> torch.Tensor:
        """
        Compute correlation weights r_ji^(0) using softmax
        Args:
            grid_features: [B, M, grid_feature_dim] grid features v_j^(0)
            cluster_centers: [C, grid_feature_dim] cluster centers r_i^(0)
        Returns:
            correlation_weights: [B, M, C] correlation weights r_ji^(0)
        """
        B, M, D = grid_features.shape
        C = cluster_centers.shape[0]
        
        # Expand dimensions for broadcasting
        # grid_features: [B, M, 1, D]
        # cluster_centers: [1, 1, C, D]
        grid_expanded = grid_features.unsqueeze(2)  # [B, M, 1, D]
        centers_expanded = cluster_centers.unsqueeze(0).unsqueeze(0)  # [1, 1, C, D]
        
        # Compute dot product: v_j^(0) · r_i^(0)
        dot_products = torch.sum(grid_expanded * centers_expanded, dim=-1)  # [B, M, C]
        
        # Add adjustable parameters
        # exp(v_j^(0) · r_i^(0) + b_r) for numerator when computing r_ji^(0)
        numerator_logits = dot_products + self.b_r  # [B, M, C]
        
        # Compute softmax over cluster dimension
        correlation_weights = F.softmax(numerator_logits / self.temperature, dim=-1)  # [B, M, C]
        
        return correlation_weights
    
    def compute_regional_features(self, 
                                  grid_features: torch.Tensor, 
                                  correlation_weights: torch.Tensor) -> torch.Tensor:
        """
        Compute regional features r_i^(1) using weighted summing
        Args:
            grid_features: [B, M, D] grid features v_j^(0)
            correlation_weights: [B, M, C] correlation weights r_ji^(0)
        Returns:
            regional_features: [B, C, D] regional features r_i^(1)
        """
        B, M, D = grid_features.shape
        C = correlation_weights.shape[-1]
        
        # Weighted summing: Σ(j=1 to M) r_ji^(0) * (v_j^(0) - r̃_i)
        # Expand r̃_i to match dimensions
        r_tilde_expanded = self.r_tilde.unsqueeze(0).unsqueeze(0).unsqueeze(0)  # [1, 1, 1, D]
        
        # Compute (v_j^(0) - r̃_i)
        centered_features = grid_features.unsqueeze(2) - r_tilde_expanded  # [B, M, 1, D] - [1, 1, 1, D] = [B, M, 1, D]
        
        # Expand correlation weights for multiplication
        weights_expanded = correlation_weights.unsqueeze(-1)  # [B, M, C, 1]
        
        # Weighted sum: r_ji^(0) * (v_j^(0) - r̃_i)
        weighted_features = weights_expanded * centered_features.unsqueeze(2)  # [B, M, C, D]
        
        # Sum over grid positions
        summed_features = torch.sum(weighted_features, dim=1)  # [B, C, D]
        
        # Apply L2 normalization
        regional_features = F.normalize(summed_features, p=2, dim=-1)  # [B, C, D]
        
        return regional_features
    
    def forward(self, grid_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for region aggregation
        Args:
            grid_features: [B, M, D] grid features from Section 3.1.1 (V^(0))
        Returns:
            regional_features: [B, C, D] aggregated regional features R^(1)
            correlation_weights: [B, M, C] correlation weights for analysis
        """
        # Step 1: Compute correlation weights r_ji^(0)
        correlation_weights = self.compute_correlation_weights(grid_features, self.cluster_centers)
        
        # Step 2: Compute regional features r_i^(1)
        regional_features = self.compute_regional_features(grid_features, correlation_weights)
        
        return regional_features, correlation_weights
    
    def get_cluster_assignments(self, correlation_weights: torch.Tensor) -> torch.Tensor:
        """
        Get hard cluster assignments for each grid position
        Args:
            correlation_weights: [B, M, C] soft correlation weights
        Returns:
            assignments: [B, M] cluster assignments (0 to C-1)
        """
        return torch.argmax(correlation_weights, dim=-1)
    
    def compute_clustering_loss(self, 
                                grid_features: torch.Tensor, 
                                correlation_weights: torch.Tensor) -> torch.Tensor:
        """
        Compute clustering loss to encourage meaningful clustering
        Args:
            grid_features: [B, M, D] grid features
            correlation_weights: [B, M, C] correlation weights
        Returns:
            loss: scalar clustering loss
        """
        B, M, D = grid_features.shape
        C = self.num_clusters
        
        # Compute cluster centers from weighted average
        weights_sum = torch.sum(correlation_weights, dim=1, keepdim=True)  # [B, 1, C]
        weights_sum = torch.clamp(weights_sum, min=1e-8)  # Avoid division by zero
        
        # Weighted average of features for each cluster
        weighted_features = torch.sum(
            grid_features.unsqueeze(2) * correlation_weights.unsqueeze(-1),
            dim=1
        )  # [B, C, D]
        computed_centers = weighted_features / weights_sum.unsqueeze(-1)  # [B, C, D]
        
        # Compute intra-cluster variance
        grid_expanded = grid_features.unsqueeze(2)  # [B, M, 1, D]
        centers_expanded = computed_centers.unsqueeze(1)  # [B, 1, C, D]
        
        # Distance from each point to each cluster center
        distances = torch.sum((grid_expanded - centers_expanded) ** 2, dim=-1)  # [B, M, C]
        
        # Weighted intra-cluster variance
        weighted_variance = torch.sum(correlation_weights * distances) / (B * M)
        
        return weighted_variance

class RegionAggregationModule(nn.Module):
    """
    Complete Region Aggregation Module with multiple aggregation methods
    """
    def __init__(self, 
                 grid_feature_dim: int = 128,
                 num_regions: int = 10,
                 num_clusters: int = 5,
                 aggregation_method: str = 'attention'):
        super().__init__()
        
        self.grid_feature_dim = grid_feature_dim
        self.num_regions = num_regions
        self.num_clusters = num_clusters
        self.aggregation_method = aggregation_method
        
        # Main region aggregation component
        self.region_aggregator = RegionAggregationFeature(
            grid_feature_dim=grid_feature_dim,
            num_regions=num_regions,
            num_clusters=num_clusters
        )
        
        # Alternative aggregation methods
        if aggregation_method == 'attention':
            self.attention_aggregator = AttentionAggregator(grid_feature_dim, num_clusters)
        elif aggregation_method == 'pooling':
            self.pooling_aggregator = PoolingAggregator(grid_feature_dim, num_clusters)
    
    def forward(self, grid_features: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward pass with multiple outputs
        Args:
            grid_features: [B, M, D] grid features
        Returns:
            output_dict: Dictionary containing various aggregated features
        """
        # Main region aggregation
        regional_features, correlation_weights = self.region_aggregator(grid_features)
        
        output_dict = {
            'regional_features': regional_features,  # R^(1): [B, C, D]
            'correlation_weights': correlation_weights,  # r_ji^(0): [B, M, C]
            'cluster_assignments': self.region_aggregator.get_cluster_assignments(correlation_weights)
        }
        
        # Alternative aggregation methods
        if self.aggregation_method == 'attention':
            attention_features = self.attention_aggregator(grid_features)
            output_dict['attention_features'] = attention_features
        
        elif self.aggregation_method == 'pooling':
            pooled_features = self.pooling_aggregator(grid_features)
            output_dict['pooled_features'] = pooled_features
        
        return output_dict

class AttentionAggregator(nn.Module):
    """Alternative attention-based aggregation"""
    def __init__(self, feature_dim: int, num_regions: int):
        super().__init__()
        self.attention = nn.MultiheadAttention(feature_dim, num_heads=8, batch_first=True)
        self.region_queries = nn.Parameter(torch.randn(num_regions, feature_dim))
    
    def forward(self, grid_features: torch.Tensor) -> torch.Tensor:
        B, M, D = grid_features.shape
        
        # Expand region queries for batch
        queries = self.region_queries.unsqueeze(0).expand(B, -1, -1)  # [B, num_regions, D]
        
        # Apply attention
        aggregated_features, _ = self.attention(queries, grid_features, grid_features)
        
        return aggregated_features

class PoolingAggregator(nn.Module):
    """Alternative pooling-based aggregation"""
    def __init__(self, feature_dim: int, num_regions: int):
        super().__init__()
        self.num_regions = num_regions
        self.conv1d = nn.Conv1d(feature_dim, num_regions, kernel_size=1)
    
    def forward(self, grid_features: torch.Tensor) -> torch.Tensor:
        B, M, D = grid_features.shape
        
        # Transpose for conv1d: [B, D, M]
        features_transposed = grid_features.transpose(1, 2)
        
        # Apply 1D convolution and global average pooling
        conv_output = self.conv1d(features_transposed)  # [B, num_regions, M]
        pooled_output = F.adaptive_avg_pool1d(conv_output, 1).squeeze(-1)  # [B, num_regions]
        
        # Expand back to feature dimension
        pooled_features = pooled_output.unsqueeze(-1).expand(-1, -1, D)  # [B, num_regions, D]
        
        return pooled_features

# Testing and example usage
def test_region_aggregation_feature():
    """Test the Region Aggregation Feature implementation"""
    
    print("Testing Region Aggregation Feature Implementation")
    print("=" * 60)
    
    # Test parameters
    batch_size = 2
    num_grid_positions = 64  # 8x8 grid
    grid_feature_dim = 128
    num_clusters = 5
    
    # Create model
    model = RegionAggregationModule(
        grid_feature_dim=grid_feature_dim,
        num_regions=10,
        num_clusters=num_clusters,
        aggregation_method='attention'
    )
    
    # Create dummy grid features (from Section 3.1.1)
    grid_features = torch.randn(batch_size, num_grid_positions, grid_feature_dim)
    
    print(f"Input grid features shape: {grid_features.shape}")
    print(f"Grid feature dimension: {grid_feature_dim}")
    print(f"Number of clusters: {num_clusters}")
    print(f"Batch size: {batch_size}")
    
    # Forward pass
    with torch.no_grad():
        output_dict = model(grid_features)
    
    # Display results
    print(f"\nResults:")
    print(f"Regional features shape: {output_dict['regional_features'].shape}")
    print(f"Correlation weights shape: {output_dict['correlation_weights'].shape}")
    print(f"Cluster assignments shape: {output_dict['cluster_assignments'].shape}")
    
    if 'attention_features' in output_dict:
        print(f"Attention features shape: {output_dict['attention_features'].shape}")
    
    # Analyze clustering
    correlation_weights = output_dict['correlation_weights']
    cluster_assignments = output_dict['cluster_assignments']
    
    print(f"\nCluster Analysis:")
    for b in range(batch_size):
        for c in range(num_clusters):
            cluster_size = torch.sum(cluster_assignments[b] == c).item()
            avg_weight = torch.mean(correlation_weights[b, cluster_assignments[b] == c, c]).item()
            print(f"  Batch {b}, Cluster {c}: {cluster_size} positions, avg weight: {avg_weight:.3f}")
    
    # Test individual components
    print(f"\n" + "="*60)
    print("Testing Individual Components")
    print("="*60)
    
    # Test correlation weights computation
    region_aggregator = model.region_aggregator
    correlation_weights = region_aggregator.compute_correlation_weights(
        grid_features, region_aggregator.cluster_centers
    )
    print(f"Correlation weights range: [{correlation_weights.min():.3f}, {correlation_weights.max():.3f}]")
    print(f"Correlation weights sum per position (should be ~1): {correlation_weights.sum(dim=-1)[0, :5]}")
    
    # Test regional features computation
    regional_features = region_aggregator.compute_regional_features(grid_features, correlation_weights)
    print(f"Regional features shape: {regional_features.shape}")
    print(f"Regional features L2 norm: {torch.norm(regional_features, p=2, dim=-1)[0]}")
    
    # Test clustering loss
    clustering_loss = region_aggregator.compute_clustering_loss(grid_features, correlation_weights)
    print(f"Clustering loss: {clustering_loss.item():.4f}")
    
    # Test learnable parameters
    print(f"\nLearnable Parameters:")
    print(f"Cluster centers shape: {region_aggregator.cluster_centers.shape}")
    print(f"b_r parameter: {region_aggregator.b_r.item():.4f}")
    print(f"b_i parameter: {region_aggregator.b_i.item():.4f}")
    print(f"r_tilde parameter norm: {torch.norm(region_aggregator.r_tilde).item():.4f}")

if __name__ == "__main__":
    test_region_aggregation_feature()

Testing Region Aggregation Feature Implementation
Input grid features shape: torch.Size([2, 64, 128])
Grid feature dimension: 128
Number of clusters: 5
Batch size: 2


RuntimeError: The size of tensor a (2) must match the size of tensor b (64) at non-singleton dimension 1