In [1]:
pip install timm torch torchvision torch-geometric

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, List
import math

class SwinTransformerPatchEmbedding(nn.Module):
    """Simplified Swin Transformer patch embedding for feature extraction"""
    def __init__(self, patch_size: int = 4, embed_dim: int = 96):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, C, H, W]
        x = self.proj(x)  # [B, embed_dim, H//patch_size, W//patch_size]
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, H*W, C]
        return x, (H, W)

class GridGraphFeature(nn.Module):
    """
    Grid-graph feature extraction based on the paper description
    """
    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 4,
                 embed_dim: int = 96,
                 gcn_hidden_dim: int = 128):
        super().__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.gcn_hidden_dim = gcn_hidden_dim

        # Calculate number of patches
        self.num_patches_per_side = image_size // patch_size  # M = k = 56 for 224/4
        self.num_patches = self.num_patches_per_side ** 2     # M*k patches total

        # Swin Transformer for patch feature extraction
        self.swin_embedding = SwinTransformerPatchEmbedding(patch_size, embed_dim)

        # Graph Convolutional Layer
        self.gcn = GraphConvolutionalLayer(embed_dim, gcn_hidden_dim)

    def extract_patch_features(self, image: torch.Tensor) -> torch.Tensor:
        """
        Extract patch features using Swin Transformer
        Args:
            image: [B, 3, 224, 224] input image
        Returns:
            X^V: [B, M*k, embed_dim] patch feature matrix
        """
        # Split image into M x k patches and extract features
        patch_features, (H, W) = self.swin_embedding(image)  # [B, M*k, embed_dim]
        return patch_features

    def compute_similarity_matrix(self, patch_features: torch.Tensor) -> torch.Tensor:
        """
        Compute similarity matrix between patches
        Args:
            patch_features: [B, M*k, embed_dim] patch features
        Returns:
            A^V: [B, M*k, M*k] adjacency matrix
        """
        B, N, D = patch_features.shape
        k = self.num_patches_per_side

        # Normalize patch features for cosine similarity
        patch_features_norm = F.normalize(patch_features, p=2, dim=-1)

        # Compute cosine similarity matrix
        similarity_matrix = torch.bmm(patch_features_norm, patch_features_norm.transpose(-2, -1))

        # Initialize adjacency matrix
        adjacency_matrix = torch.zeros_like(similarity_matrix)

        # Create grid indices
        indices = torch.arange(N, device=patch_features.device).reshape(k, k)

        # Horizontal neighbors (right and left)
        right_neighbors = indices[:, :-1].flatten()  # i
        right_targets = indices[:, 1:].flatten()     # j
        left_neighbors = right_targets               # i
        left_targets = right_neighbors               # j

        # Vertical neighbors (down and up)
        down_neighbors = indices[:-1, :].flatten()   # i
        down_targets = indices[1:, :].flatten()      # j
        up_neighbors = down_targets                  # i
        up_targets = down_neighbors                  # j

        # Set adjacency for neighbors to 1
        adjacency_matrix[:, right_neighbors, right_targets] = 1.0
        adjacency_matrix[:, left_neighbors, left_targets] = 1.0
        adjacency_matrix[:, down_neighbors, down_targets] = 1.0
        adjacency_matrix[:, up_neighbors, up_targets] = 1.0

        # Set diagonal to 1
        adjacency_matrix[:, torch.arange(N), torch.arange(N)] = 1.0

        # Fill non-adjacent positions with similarity values
        mask = adjacency_matrix == 0
        adjacency_matrix[mask] = similarity_matrix[mask]

        return adjacency_matrix

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of grid-graph feature extraction
        Args:
            image: [B, 3, 224, 224] input image
        Returns:
            V^(0): [B, M*k, hidden_dim] initial context grid representation
        """
        # Step 1: Extract patch features using Swin Transformer
        patch_features = self.extract_patch_features(image)  # X^V

        # Step 2: Compute similarity-based adjacency matrix
        adjacency_matrix = self.compute_similarity_matrix(patch_features)  # A^V

        # Step 3: Apply Graph Convolutional Layer
        grid_representation = self.gcn(patch_features, adjacency_matrix)  # V^(0)

        return grid_representation

class GraphConvolutionalLayer(nn.Module):
    """
    Graph Convolutional Layer implementation
    V^(0) = σ(Ã^T VW^T)
    """
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Weight matrix W
        self.weight = nn.Parameter(torch.FloatTensor(input_dim, output_dim))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, node_features: torch.Tensor, adjacency_matrix: torch.Tensor) -> torch.Tensor:
        """
        Args:
            node_features: [B, N, input_dim] node feature matrix V
            adjacency_matrix: [B, N, N] adjacency matrix A^V
        Returns:
            output: [B, N, output_dim] output features V^(0)
        """
        # Compute degree matrix D
        degree_matrix = torch.sum(adjacency_matrix, dim=-1, keepdim=True)  # [B, N, 1]
        degree_matrix = torch.clamp(degree_matrix, min=1.0)  # Avoid division by zero

        # Normalize adjacency matrix: Ã = D^(-1/2) A^V D^(-1/2)
        degree_inv_sqrt = torch.pow(degree_matrix, -0.5)
        normalized_adj = degree_inv_sqrt * adjacency_matrix * degree_inv_sqrt.transpose(-2, -1)

        # Apply weight transformation: VW^T
        transformed_features = torch.matmul(node_features, self.weight)  # [B, N, output_dim]

        # Graph convolution: Ã^T VW^T
        output = torch.bmm(normalized_adj.transpose(-2, -1), transformed_features)

        # Apply activation function (ReLU)
        output = F.relu(output)

        return output

# Example usage and testing
def test_grid_graph_feature():
    """Test the Grid-Graph Feature implementation"""

    # Create model
    model = GridGraphFeature(
        image_size=224,
        patch_size=4,
        embed_dim=96,
        gcn_hidden_dim=128
    )

    # Create dummy input
    batch_size = 2
    dummy_image = torch.randn(batch_size, 3, 224, 224)

    print(f"Input image shape: {dummy_image.shape}")
    print(f"Number of patches: {model.num_patches}")
    print(f"Patches per side: {model.num_patches_per_side}")

    # Forward pass
    with torch.no_grad():
        grid_representation = model(dummy_image)

    print(f"Output grid representation shape: {grid_representation.shape}")
    print(f"Expected shape: [{batch_size}, {model.num_patches}, {model.gcn_hidden_dim}]")

    # Test individual components
    print("\n--- Testing individual components ---")

    # Test patch feature extraction
    patch_features = model.extract_patch_features(dummy_image)
    print(f"Patch features shape: {patch_features.shape}")

    # Test similarity matrix computation
    similarity_matrix = model.compute_similarity_matrix(patch_features)
    print(f"Similarity matrix shape: {similarity_matrix.shape}")

    # Verify adjacency matrix properties
    print(f"Adjacency matrix diagonal sum: {torch.diagonal(similarity_matrix, dim1=-2, dim2=-1).sum()}")
    print(f"Adjacency matrix range: [{similarity_matrix.min():.3f}, {similarity_matrix.max():.3f}]")

if __name__ == "__main__":
    test_grid_graph_feature()

Input image shape: torch.Size([2, 3, 224, 224])
Number of patches: 3136
Patches per side: 56
Output grid representation shape: torch.Size([2, 3136, 128])
Expected shape: [2, 3136, 128]

--- Testing individual components ---
Patch features shape: torch.Size([2, 3136, 96])
Similarity matrix shape: torch.Size([2, 3136, 3136])
Adjacency matrix diagonal sum: 6272.0
Adjacency matrix range: [-0.720, 1.000]


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Tuple, List, Dict
import torchvision.ops as ops

class FasterRCNN(nn.Module):
    """Simplified Faster R-CNN for object region detection"""
    def __init__(self, backbone_dim: int = 512, num_classes: int = 80):
        super().__init__()
        self.backbone_dim = backbone_dim
        self.num_classes = num_classes

        # Simplified backbone (normally ResNet)
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, backbone_dim, 3, padding=1),
            nn.ReLU(inplace=True)
        )

        # ROI pooling output dimension
        self.roi_pool_size = 7
        self.roi_pooling = ops.RoIPool(output_size=(self.roi_pool_size, self.roi_pool_size), spatial_scale=1/16)

        # Feature extraction after ROI pooling
        self.roi_head = nn.Sequential(
            nn.Linear(backbone_dim * self.roi_pool_size * self.roi_pool_size, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True)
        )

    def forward(self, images: torch.Tensor, boxes: torch.Tensor) -> torch.Tensor:
        """
        Extract features for given region boxes
        Args:
            images: [B, 3, H, W] input images
            boxes: [B, P, 4] region boxes in format [x1, y1, x2, y2]
        Returns:
            region_features: [B, P, 512] features for each region
        """
        B, _, H, W = images.shape
        P = boxes.shape[1]

        # Extract backbone features
        backbone_features = self.backbone(images)  # [B, backbone_dim, H', W']

        # Ensure boxes are on the same device as images
        boxes = boxes.to(images.device)

        # Prepare boxes for ROI pooling
        batch_indices = torch.arange(B, device=images.device).view(B, 1).expand(B, P).reshape(-1)  # [B*P]
        flat_boxes = boxes.view(-1, 4)  # [B*P, 4]
        roi_boxes = torch.cat([batch_indices[:, None], flat_boxes], dim=1)  # [B*P, 5] [batch_idx, x1, y1, x2, y2]

        # ROI pooling
        pooled_features = self.roi_pooling(backbone_features, roi_boxes)  # [B*P, backbone_dim, 7, 7]

        # Flatten and process through ROI head
        pooled_flat = pooled_features.view(B * P, -1)  # [B*P, backbone_dim * 7 * 7]
        roi_features = self.roi_head(pooled_flat)  # [B*P, 512]

        # Reshape back to [B, P, 512]
        region_features = roi_features.view(B, P, 512)

        return region_features

class MultiHeadAttention(nn.Module):
    """Multi-Head Attention mechanism for region features"""
    def __init__(self, d_model: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Output projection
        self.w_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Args:
            query, key, value: [B, P, d_model] region features
        Returns:
            output: [B, P, d_model] attended features
        """
        B, P, d_model = query.shape

        # Linear projections
        Q = self.w_q(query)  # [B, P, d_model]
        K = self.w_k(key)    # [B, P, d_model]
        V = self.w_v(value)  # [B, P, d_model]

        # Reshape for multi-head attention
        Q = Q.view(B, P, self.num_heads, self.d_k).transpose(1, 2)  # [B, num_heads, P, d_k]
        K = K.view(B, P, self.num_heads, self.d_k).transpose(1, 2)  # [B, num_heads, P, d_k]
        V = V.view(B, P, self.num_heads, self.d_k).transpose(1, 2)  # [B, num_heads, P, d_k]

        # Scaled dot-product attention
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [B, num_heads, P, P]
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        attended_values = torch.matmul(attention_weights, V)  # [B, num_heads, P, d_k]

        # Concatenate heads
        attended_values = attended_values.transpose(1, 2).contiguous().view(B, P, d_model)

        # Final linear projection
        output = self.w_o(attended_values)

        return output

class SelfAttention(nn.Module):
    """Self-Attention mechanism as described in the paper"""
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.scale = math.sqrt(d_model)

    def forward(self, R: torch.Tensor) -> torch.Tensor:
        """
        Compute QK^T Self-Attention
        Args:
            R: [B, P, d] region features where Q = K = V = R
        Returns:
            attention_output: [B, P, d] attended features
        """
        # Q = K = V = R
        Q = K = V = R  # [B, P, d]

        # Compute attention scores: QK^T / sqrt(d)
        attention_scores = torch.bmm(Q, K.transpose(-2, -1)) / self.scale  # [B, P, P]
        attention_weights = F.softmax(attention_scores, dim=-1)  # [B, P, P]
        attention_output = torch.bmm(attention_weights, V)  # [B, P, d]

        return attention_output

class RegionGraphFeature(nn.Module):
    """
    Region Graph Feature extraction using Faster R-CNN and Multi-Head Attention
    """
    def __init__(self,
                 num_regions: int = 100,
                 region_feature_dim: int = 512,
                 num_attention_heads: int = 8,
                 output_dim: int = 512):
        super().__init__()

        self.num_regions = num_regions
        self.region_feature_dim = region_feature_dim
        self.num_attention_heads = num_attention_heads
        self.output_dim = output_dim

        # Faster R-CNN for region detection and feature extraction
        self.faster_rcnn = FasterRCNN(backbone_dim=512, num_classes=80)

        # Self-Attention mechanism
        self.self_attention = SelfAttention(region_feature_dim)

        # Multi-Head Attention mechanism
        self.multi_head_attention = MultiHeadAttention(
            d_model=region_feature_dim,
            num_heads=num_attention_heads
        )

        # Final projection layer
        self.output_projection = nn.Linear(region_feature_dim, output_dim)

    def generate_dummy_boxes(self, batch_size: int, num_boxes: int, image_size: Tuple[int, int], device: torch.device) -> torch.Tensor:
        """
        Generate dummy bounding boxes for testing
        In practice, these would come from Faster R-CNN's RPN
        Args:
            batch_size: Number of images in batch
            num_boxes: Number of boxes per image
            image_size: (H, W) of input images
            device: Device to place tensors on (CPU or CUDA)
        """
        H, W = image_size

        # Vectorized box generation
        x1 = torch.randint(0, W//2, (batch_size, num_boxes), device=device).float()
        y1 = torch.randint(0, H//2, (batch_size, num_boxes), device=device).float()
        x2 = x1 + torch.randint(W//4, W//2, (batch_size, num_boxes), device=device).float()
        y2 = y1 + torch.randint(H//4, H//2, (batch_size, num_boxes), device=device).float()

        # Clamp boxes to image bounds
        x2 = torch.clamp(x2, max=W-1)
        y2 = torch.clamp(y2, max=H-1)

        # Stack coordinates: [x1, y1, x2, y2]
        boxes = torch.stack([x1, y1, x2, y2], dim=-1)  # [batch_size, num_boxes, 4]

        return boxes

    def forward(self, images: torch.Tensor, region_boxes: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass for region graph feature extraction
        Args:
            images: [B, 3, H, W] input images
            region_boxes: [B, P, 4] region boxes (optional, will generate dummy if None)
        Returns:
            R^(0): [B, P, output_dim] initial region representation
        """
        B, C, H, W = images.shape
        device = images.device  # Get device from input images

        # Generate dummy boxes if not provided
        if region_boxes is None:
            region_boxes = self.generate_dummy_boxes(B, self.num_regions, (H, W), device)
        else:
            region_boxes = region_boxes.to(device)

        # Enable mixed precision for faster computation on GPU
        with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
            # Step 1: Extract region features using Faster R-CNN
            region_features = self.faster_rcnn(images, region_boxes)  # R: [B, P, d]

            # Step 2: Apply Self-Attention mechanism
            sa_output = self.self_attention(region_features)  # [B, P, d]

            # Step 3: Apply Multi-Head Attention
            mha_output = self.multi_head_attention(sa_output, sa_output, sa_output)  # [B, P, d]

            # Step 4: Combine with residual connection
            combined_features = region_features + mha_output  # [B, P, d]

            # Step 5: Final projection to get R^(0)
            initial_region_representation = self.output_projection(combined_features)  # [B, P, output_dim]

        return initial_region_representation

    def get_region_vectors(self, region_representation: torch.Tensor) -> List[torch.Tensor]:
        """
        Extract individual region vectors {r_i}
        Args:
            region_representation: [B, P, d] region features
        Returns:
            List of region vectors for each batch
        """
        B, P, d = region_representation.shape
        region_vectors = []

        for b in range(B):
            batch_vectors = [region_representation[b, p, :] for p in range(P)]
            region_vectors.append(batch_vectors)

        return region_vectors

# Example usage and testing
def test_region_graph_feature():
    """Test the Region Graph Feature implementation"""

    # Determine device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create model and move to device
    model = RegionGraphFeature(
        num_regions=50,
        region_feature_dim=512,
        num_attention_heads=8,
        output_dim=512
    ).to(device)

    # Create dummy input and move to device
    batch_size = 2
    dummy_images = torch.randn(batch_size, 3, 224, 224).to(device)

    print(f"Input images shape: {dummy_images.shape}")
    print(f"Number of regions: {model.num_regions}")
    print(f"Region feature dimension: {model.region_feature_dim}")

    # Forward pass
    with torch.no_grad():
        region_representation = model(dummy_images)

    print(f"Output region representation shape: {region_representation.shape}")
    print(f"Expected shape: [{batch_size}, {model.num_regions}, {model.output_dim}]")

    # Test with custom boxes
    print("\n--- Testing with custom region boxes ---")
    custom_boxes = torch.tensor([
        [[10, 10, 50, 50], [60, 60, 100, 100], [120, 30, 180, 90]],
        [[20, 20, 80, 80], [90, 10, 150, 70], [30, 100, 90, 160]]
    ], dtype=torch.float32).to(device)

    with torch.no_grad():
        custom_region_representation = model(dummy_images, custom_boxes)

    print(f"Custom region representation shape: {custom_region_representation.shape}")

    # Test individual components
    print("\n--- Testing individual components ---")

    # Test Faster R-CNN feature extraction
    dummy_boxes = model.generate_dummy_boxes(batch_size, 10, (224, 224), device)
    region_features = model.faster_rcnn(dummy_images, dummy_boxes)
    print(f"Faster R-CNN features shape: {region_features.shape}")

    # Test Self-Attention
    sa_output = model.self_attention(region_features)
    print(f"Self-Attention output shape: {sa_output.shape}")

    # Test Multi-Head Attention
    mha_output = model.multi_head_attention(region_features, region_features, region_features)
    print(f"Multi-Head Attention output shape: {mha_output.shape}")

    # Extract individual region vectors
    region_vectors = model.get_region_vectors(region_representation)
    print(f"Number of region vector batches: {len(region_vectors)}")
    print(f"Number of vectors per batch: {len(region_vectors[0])}")
    print(f"Each region vector shape: {region_vectors[0][0].shape}")

if __name__ == "__main__":
    test_region_graph_feature()

Using device: cuda
Input images shape: torch.Size([2, 3, 224, 224])
Number of regions: 50
Region feature dimension: 512


  with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):


Output region representation shape: torch.Size([2, 50, 512])
Expected shape: [2, 50, 512]

--- Testing with custom region boxes ---
Custom region representation shape: torch.Size([2, 3, 512])

--- Testing individual components ---
Faster R-CNN features shape: torch.Size([2, 10, 512])
Self-Attention output shape: torch.Size([2, 10, 512])
Multi-Head Attention output shape: torch.Size([2, 10, 512])
Number of region vector batches: 2
Number of vectors per batch: 50
Each region vector shape: torch.Size([512])


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
import numpy as np
import math
from typing import List, Dict, Tuple, Optional
import spacy
import networkx as nx

class BERTEmbedding(nn.Module):
    """BERT-based text embedding for semantic features"""
    def __init__(self, model_name: str = 'bert-base-uncased', embedding_dim: int = 768):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.bert_model = BertModel.from_pretrained(model_name)

        # Freeze BERT parameters for efficiency (optional)
        for param in self.bert_model.parameters():
            param.requires_grad = False

    def forward(self, text_sequences: List[str]) -> torch.Tensor:
        """
        Convert text sequences to BERT embeddings
        Args:
            text_sequences: List of text strings
        Returns:
            embeddings: [N, embedding_dim] where N is number of sequences
        """
        # Tokenize texts
        encoded = self.tokenizer(
            text_sequences,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )

        # Get BERT embeddings
        with torch.no_grad():
            outputs = self.bert_model(**encoded)
            # Use [CLS] token embedding as sentence representation
            embeddings = outputs.last_hidden_state[:, 0, :]  # [N, embedding_dim]

        return embeddings

class DependencyParser:
    """Dependency parsing for constructing semantic graphs"""
    def __init__(self):
        # Load spaCy model for dependency parsing
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            print("Warning: spaCy English model not found. Please run 'python -m spacy download en_core_web_sm' to install.")
            self.nlp = None

    def parse_text(self, text: str) -> List[Dict]:
        """
        Parse text and extract dependency relationships
        Args:
            text: Input text string
        Returns:
            dependencies: List of dependency relations
        """
        if self.nlp is None:
            # Dummy dependencies for testing
            words = text.split()
            return [{"head": i, "child": i+1, "relation": "dummy"} for i in range(len(words)-1)]

        doc = self.nlp(text)
        dependencies = []

        for token in doc:
            if token.head != token:  # Skip root
                dependencies.append({
                    "head": token.head.i,
                    "child": token.i,
                    "relation": token.dep_,
                    "head_text": token.head.text,
                    "child_text": token.text
                })

        return dependencies

    def extract_words(self, text: str) -> List[str]:
        """Extract individual words from text"""
        if self.nlp is None:
            return text.split()

        doc = self.nlp(text)
        return [token.text for token in doc if not token.is_punct]

class SemanticGraphConstructor:
    """Construct semantic graphs from dependency trees"""
    def __init__(self):
        self.dependency_parser = DependencyParser()

    def construct_adjacency_matrix(self, text: str) -> Tuple[torch.Tensor, List[str]]:
        """
        Construct adjacency matrix from dependency tree
        Args:
            text: Input text string
        Returns:
            adjacency_matrix: [N, N] adjacency matrix
            words: List of words corresponding to matrix indices
        """
        # Extract words and dependencies
        words = self.dependency_parser.extract_words(text)
        dependencies = self.dependency_parser.parse_text(text)

        N = len(words)
        if N == 0:
            return torch.zeros(1, 1), [""]

        # Initialize adjacency matrix
        adjacency_matrix = torch.zeros(N, N)

        # Fill adjacency matrix based on dependencies
        for dep in dependencies:
            head_idx = dep["head"]
            child_idx = dep["child"]

            # Ensure indices are within bounds
            if 0 <= head_idx < N and 0 <= child_idx < N:
                # Undirected graph: set both directions
                adjacency_matrix[head_idx, child_idx] = 1.0
                adjacency_matrix[child_idx, head_idx] = 1.0

        # Add self-connections (diagonal)
        for i in range(N):
            adjacency_matrix[i, i] = 1.0

        return adjacency_matrix, words

    def correlation_function(self, wi: str, wj: str) -> float:
        """
        Compute correlation between two words D(wi, wj)
        This is a simplified implementation - in practice, could use:
        - Word embeddings similarity
        - Co-occurrence statistics
        - Semantic similarity measures
        """
        if wi == wj:
            return 1.0

        # Jaccard similarity of character sets
        set_i = set(wi.lower())
        set_j = set(wj.lower())

        if len(set_i.union(set_j)) == 0:
            return 0.0

        correlation = len(set_i.intersection(set_j)) / len(set_i.union(set_j))
        return correlation

class SemanticGraphConvolutionalLayer(nn.Module):
    """
    Semantic Graph Convolutional Layer
    S^(0) = σ(Ã^S σ(Ã^S SW₁^S) W₂^S)
    """
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Weight matrices
        self.W1 = nn.Parameter(torch.FloatTensor(input_dim, hidden_dim))
        self.W2 = nn.Parameter(torch.FloatTensor(hidden_dim, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        stdv1 = 1. / math.sqrt(self.W1.size(1))
        self.W1.data.uniform_(-stdv1, stdv1)

        stdv2 = 1. / math.sqrt(self.W2.size(1))
        self.W2.data.uniform_(-stdv2, stdv2)

    def normalize_adjacency_matrix(self, adjacency_matrix: torch.Tensor) -> torch.Tensor:
        """
        Normalize adjacency matrix: Ã = D^(-1/2) A D^(-1/2)
        """
        # Compute degree matrix
        degree = torch.sum(adjacency_matrix, dim=-1, keepdim=True)  # [N, 1]
        degree = torch.clamp(degree, min=1.0)  # Avoid division by zero

        # D^(-1/2)
        degree_inv_sqrt = torch.pow(degree, -0.5)

        # Normalize: D^(-1/2) A D^(-1/2)
        normalized_adj = degree_inv_sqrt * adjacency_matrix * degree_inv_sqrt.transpose(-2, -1)

        return normalized_adj

    def forward(self, node_features: torch.Tensor, adjacency_matrix: torch.Tensor) -> torch.Tensor:
        """
        Args:
            node_features: [N, input_dim] node features S
            adjacency_matrix: [N, N] adjacency matrix A^S
        Returns:
            output: [N, output_dim] output features S^(0)
        """
        # Normalize adjacency matrix
        normalized_adj = self.normalize_adjacency_matrix(adjacency_matrix)  # Ã^S

        # First GCN layer: Ã^S S W₁^S
        h1 = torch.matmul(node_features, self.W1)  # [N, hidden_dim]
        h1 = torch.matmul(normalized_adj, h1)      # [N, hidden_dim]
        h1 = F.relu(h1)  # σ(Ã^S S W₁^S)

        # Second GCN layer: Ã^S σ(Ã^S S W₁^S) W₂^S
        h2 = torch.matmul(h1, self.W2)             # [N, output_dim]
        output = torch.matmul(normalized_adj, h2)   # [N, output_dim]
        output = F.relu(output)  # σ(Ã^S σ(Ã^S S W₁^S) W₂^S)

        return output

class SemanticGraphFeature(nn.Module):
    """
    Complete Semantic Graph Feature extraction pipeline
    """
    def __init__(self,
                 bert_model_name: str = 'bert-base-uncased',
                 bert_dim: int = 768,
                 gcn_hidden_dim: int = 512,
                 output_dim: int = 256):
        super().__init__()

        self.bert_dim = bert_dim
        self.gcn_hidden_dim = gcn_hidden_dim
        self.output_dim = output_dim

        # BERT embedding for word-level features
        self.bert_embedding = BERTEmbedding(bert_model_name, bert_dim)

        # Semantic graph constructor
        self.graph_constructor = SemanticGraphConstructor()

        # Semantic Graph Convolutional Layer
        self.semantic_gcn = SemanticGraphConvolutionalLayer(
            input_dim=bert_dim,
            hidden_dim=gcn_hidden_dim,
            output_dim=output_dim
        )

    def process_single_text(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Process a single text to extract semantic graph features
        Args:
            text: Input text string
        Returns:
            semantic_features: [N, output_dim] semantic node features
            adjacency_matrix: [N, N] semantic adjacency matrix
        """
        # Step 1: Construct semantic graph
        adjacency_matrix, words = self.graph_constructor.construct_adjacency_matrix(text)

        if len(words) == 0 or words == [""]:
            # Handle empty text
            return torch.zeros(1, self.output_dim), torch.zeros(1, 1)

        # Step 2: Get BERT embeddings for words
        word_embeddings = self.bert_embedding(words)  # [N, bert_dim]

        # Step 3: Apply Semantic GCN
        semantic_features = self.semantic_gcn(word_embeddings, adjacency_matrix)  # [N, output_dim]

        return semantic_features, adjacency_matrix

    def forward(self, text_list: List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Process multiple texts
        Args:
            text_list: List of text strings
        Returns:
            semantic_features_list: List of [Ni, output_dim] tensors
            adjacency_matrices_list: List of [Ni, Ni] tensors
        """
        semantic_features_list = []
        adjacency_matrices_list = []

        for text in text_list:
            semantic_features, adjacency_matrix = self.process_single_text(text)
            semantic_features_list.append(semantic_features)
            adjacency_matrices_list.append(adjacency_matrix)

        return semantic_features_list, adjacency_matrices_list

    def get_semantic_vectors(self, semantic_features_list: List[torch.Tensor]) -> List[List[torch.Tensor]]:
        """
        Extract individual semantic vectors {s_i}
        Args:
            semantic_features_list: List of [Ni, output_dim] tensors
        Returns:
            List of lists containing individual semantic vectors
        """
        semantic_vectors = []

        for semantic_features in semantic_features_list:
            N, d = semantic_features.shape
            text_vectors = [semantic_features[i, :] for i in range(N)]
            semantic_vectors.append(text_vectors)

        return semantic_vectors

# Example usage and testing
def test_semantic_graph_feature():
    """Test the Semantic Graph Feature implementation"""

    print("Testing Semantic Graph Feature Implementation")
    print("=" * 50)

    # Create model
    model = SemanticGraphFeature(
        bert_model_name='bert-base-uncased',
        bert_dim=768,
        gcn_hidden_dim=512,
        output_dim=256
    )

    # Test texts (image captions)
    test_texts = [
        "A cat sitting on a wooden table",
        "Two dogs playing in the park",
        "Beautiful sunset over the ocean",
        "Person riding bicycle on street"
    ]

    print(f"Test texts: {len(test_texts)} captions")
    for i, text in enumerate(test_texts):
        print(f"  {i+1}. {text}")

    # Process texts
    semantic_features_list, adjacency_matrices_list = model(test_texts)

    print(f"\nResults:")
    print(f"Number of processed texts: {len(semantic_features_list)}")

    for i, (features, adj_matrix) in enumerate(zip(semantic_features_list, adjacency_matrices_list)):
        print(f"\nText {i+1}: '{test_texts[i]}'")
        print(f"  Semantic features shape: {features.shape}")
        print(f"  Adjacency matrix shape: {adj_matrix.shape}")
        print(f"  Number of words/nodes: {features.shape[0]}")
        print(f"  Feature dimension: {features.shape[1]}")
        print(f"  Graph density: {(adj_matrix.sum() - adj_matrix.trace()) / (adj_matrix.numel() - adj_matrix.shape[0]):.3f}")

    # Test individual components
    print(f"\n" + "="*50)
    print("Testing Individual Components")
    print("="*50)

    # Test dependency parsing
    test_text = "A cat sitting on a wooden table"
    dependencies = model.graph_constructor.dependency_parser.parse_text(test_text)
    words = model.graph_constructor.dependency_parser.extract_words(test_text)

    print(f"\nDependency parsing for: '{test_text}'")
    print(f"Words: {words}")
    print(f"Dependencies: {len(dependencies)}")
    for dep in dependencies[:3]:  # Show first 3
        print(f"  {dep}")

    # Test adjacency matrix construction
    adj_matrix, extracted_words = model.graph_constructor.construct_adjacency_matrix(test_text)
    print(f"\nAdjacency matrix shape: {adj_matrix.shape}")
    print(f"Extracted words: {extracted_words}")
    print(f"Adjacency matrix:\n{adj_matrix}")

    # Test BERT embeddings
    word_embeddings = model.bert_embedding(words[:3])  # Test first 3 words
    print(f"\nBERT embeddings shape for 3 words: {word_embeddings.shape}")
    print(f"Embedding dimension: {word_embeddings.shape[1]}")

    # Test semantic vectors extraction
    semantic_vectors = model.get_semantic_vectors(semantic_features_list)
    print(f"\nSemantic vectors:")
    print(f"Number of texts: {len(semantic_vectors)}")
    if len(semantic_vectors) > 0:
        print(f"Vectors in first text: {len(semantic_vectors[0])}")
        if len(semantic_vectors[0]) > 0:
            print(f"Each vector shape: {semantic_vectors[0][0].shape}")

if __name__ == "__main__":
    test_semantic_graph_feature()

2025-06-08 16:29:01.561898: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749400141.764408      93 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749400141.820639      93 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Testing Semantic Graph Feature Implementation


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Test texts: 4 captions
  1. A cat sitting on a wooden table
  2. Two dogs playing in the park
  3. Beautiful sunset over the ocean
  4. Person riding bicycle on street

Results:
Number of processed texts: 4

Text 1: 'A cat sitting on a wooden table'
  Semantic features shape: torch.Size([7, 256])
  Adjacency matrix shape: torch.Size([7, 7])
  Number of words/nodes: 7
  Feature dimension: 256
  Graph density: 0.286

Text 2: 'Two dogs playing in the park'
  Semantic features shape: torch.Size([6, 256])
  Adjacency matrix shape: torch.Size([6, 6])
  Number of words/nodes: 6
  Feature dimension: 256
  Graph density: 0.333

Text 3: 'Beautiful sunset over the ocean'
  Semantic features shape: torch.Size([5, 256])
  Adjacency matrix shape: torch.Size([5, 5])
  Number of words/nodes: 5
  Feature dimension: 256
  Graph density: 0.400

Text 4: 'Person riding bicycle on street'
  Semantic features shape: torch.Size([5, 256])
  Adjacency matrix shape: torch.Size([5, 5])
  Number of words/nodes: 5


**Region-grid aggregator**

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Tuple, List, Dict

class RegionGridAggregator(nn.Module):
    """
    Region-Grid Aggregator that combines grid features with region features
    to create enhanced region representations
    """
    def __init__(self, 
                 grid_feature_dim: int = 128,
                 region_feature_dim: int = 512,
                 hidden_dim: int = 256,
                 num_regions: int = 100):
        super().__init__()
        
        self.grid_feature_dim = grid_feature_dim
        self.region_feature_dim = region_feature_dim
        self.hidden_dim = hidden_dim
        self.num_regions = num_regions
        
        # Learnable parameters for attention calculation
        self.b_i = nn.Parameter(torch.randn(1))
        self.b_j = nn.Parameter(torch.randn(1))
        
        # Learnable parameter r_tilde for normalization
        self.r_tilde = nn.Parameter(torch.randn(grid_feature_dim))
        
        # Projection layer to align dimensions
        self.grid_projection = nn.Linear(grid_feature_dim, region_feature_dim)
        
    def compute_region_grid_correlation(self, 
                                      grid_features: torch.Tensor,
                                      region_centers: torch.Tensor) -> torch.Tensor:
        """
        Compute correlation between region centers and grid features
        Args:
            grid_features: [B, M*k, grid_dim] from grid-graph feature
            region_centers: [B, P, region_dim] potential region centers
        Returns:
            correlation_matrix: [B, P, M*k] correlation scores
        """
        B, num_patches, grid_dim = grid_features.shape
        B, num_regions, region_dim = region_centers.shape
        
        # Ensure consistent dtype
        grid_features = grid_features.float()
        region_centers = region_centers.float()
        
        # Project grid features to same dimension as region features
        grid_projected = self.grid_projection(grid_features)  # [B, M*k, region_dim]
        
        # Compute dot product correlation
        correlation = torch.bmm(region_centers, grid_projected.transpose(-2, -1))  # [B, P, M*k]
        
        # Add learnable bias terms (ensure same dtype)
        correlation = correlation + self.b_i.float() + self.b_j.float()  # Broadcasting
        
        return correlation
    
    def compute_attention_weights(self, correlation_matrix: torch.Tensor) -> torch.Tensor:
        """
        Compute attention weights using softmax normalization
        Args:
            correlation_matrix: [B, P, M*k] correlation scores
        Returns:
            attention_weights: [B, P, M*k] normalized attention weights
        """
        # Apply softmax over grid features dimension (M*k)
        attention_weights = F.softmax(correlation_matrix, dim=-1)  # [B, P, M*k]
        
        return attention_weights
    
    def aggregate_grid_to_region(self, 
                                grid_features: torch.Tensor,
                                attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Aggregate grid features to region features using attention weights
        Args:
            grid_features: [B, M*k, grid_dim] grid features
            attention_weights: [B, P, M*k] attention weights
        Returns:
            aggregated_features: [B, P, grid_dim] aggregated region features
        """
        # Weighted sum of grid features
        aggregated_features = torch.bmm(attention_weights, grid_features)  # [B, P, grid_dim]
        
        return aggregated_features
    
    def compute_region_aggregation_feature(self, 
                                         aggregated_features: torch.Tensor,
                                         attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Compute final region aggregation feature with L2 normalization
        Args:
            aggregated_features: [B, P, grid_dim] aggregated features
            attention_weights: [B, P, M*k] attention weights
        Returns:
            region_agg_features: [B, P, grid_dim] final region aggregation features
        """
        B, P, grid_dim = aggregated_features.shape
        
        # Ensure consistent dtype
        aggregated_features = aggregated_features.float()
        
        # Expand r_tilde for batch processing
        r_tilde_expanded = self.r_tilde.float().unsqueeze(0).unsqueeze(0).expand(B, P, -1)  # [B, P, grid_dim]
        
        # Compute difference from learnable parameter
        diff = aggregated_features - r_tilde_expanded  # [B, P, grid_dim]
        
        # Apply L2 norm standardization
        region_agg_features = F.normalize(diff, p=2, dim=-1)  # [B, P, grid_dim]
        
        return region_agg_features

    def forward(self, 
                grid_features: torch.Tensor,
                region_centers: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of Region-Grid Aggregator
        Args:
            grid_features: [B, M*k, grid_dim] from GridGraphFeature
            region_centers: [B, P, region_dim] potential region centers
        Returns:
            region_aggregation_features: [B, P, grid_dim] enhanced region features
        """
        # Step 1: Compute correlation between regions and grid features
        correlation_matrix = self.compute_region_grid_correlation(grid_features, region_centers)
        
        # Step 2: Compute attention weights
        attention_weights = self.compute_attention_weights(correlation_matrix)
        
        # Step 3: Aggregate grid features to regions
        aggregated_features = self.aggregate_grid_to_region(grid_features, attention_weights)
        
        # Step 4: Compute final region aggregation features
        region_aggregation_features = self.compute_region_aggregation_feature(aggregated_features, attention_weights)
        
        return region_aggregation_features


class MultiModalGraphAttention(nn.Module):
    """
    Multi-Modal Graph Attention that combines Grid-Graph and Region-Graph features
    """
    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 4,
                 grid_embed_dim: int = 96,
                 grid_hidden_dim: int = 128,
                 num_regions: int = 100,
                 region_feature_dim: int = 512,
                 num_attention_heads: int = 8,
                 fusion_dim: int = 512):
        super().__init__()
        
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_regions = num_regions
        self.fusion_dim = fusion_dim
        
        # Import the previously defined modules
        
        # Grid-Graph Feature Extractor
        self.grid_graph_extractor = GridGraphFeature(
            image_size=image_size,
            patch_size=patch_size,
            embed_dim=grid_embed_dim,
            gcn_hidden_dim=grid_hidden_dim
        )
        
        # Region-Graph Feature Extractor
        self.region_graph_extractor = RegionGraphFeature(
            num_regions=num_regions,
            region_feature_dim=region_feature_dim,
            num_attention_heads=num_attention_heads,
            output_dim=region_feature_dim
        )
        
        # Region-Grid Aggregator
        self.region_grid_aggregator = RegionGridAggregator(
            grid_feature_dim=grid_hidden_dim,
            region_feature_dim=region_feature_dim,
            hidden_dim=fusion_dim,
            num_regions=num_regions
        )
        
        # Feature fusion layers
        self.region_fusion = nn.Sequential(
            nn.Linear(region_feature_dim + grid_hidden_dim, fusion_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(fusion_dim, fusion_dim)
        )
        
        self.grid_fusion = nn.Sequential(
            nn.Linear(grid_hidden_dim, fusion_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1)
        )
        
        # Cross-modal attention
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=fusion_dim,
            num_heads=num_attention_heads,
            dropout=0.1,
            batch_first=True
        )
        
        # Final output projection
        self.output_projection = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(fusion_dim, fusion_dim)
        )
    
    def forward(self, images: torch.Tensor, region_boxes: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass of Multi-Modal Graph Attention
        Args:
            images: [B, 3, H, W] input images
            region_boxes: [B, P, 4] region boxes (optional)
        Returns:
            Dictionary containing:
                - 'fused_features': [B, P+M*k, fusion_dim] fused multi-modal features
                - 'grid_features': [B, M*k, fusion_dim] processed grid features
                - 'region_features': [B, P, fusion_dim] enhanced region features
                - 'region_aggregation': [B, P, grid_hidden_dim] region-grid aggregation
        """
        B = images.shape[0]
        device = images.device
        
        # Disable mixed precision to avoid dtype issues
        with torch.cuda.amp.autocast(enabled=False):
            # Ensure input is float32
            images = images.float()
            if region_boxes is not None:
                region_boxes = region_boxes.float()
            
            # Step 1: Extract Grid-Graph Features
            grid_features = self.grid_graph_extractor(images)  # [B, M*k, grid_hidden_dim]
            
            # Step 2: Extract Region-Graph Features
            region_features = self.region_graph_extractor(images, region_boxes)  # [B, P, region_feature_dim]
            
            # Step 3: Apply Region-Grid Aggregator
            region_aggregation = self.region_grid_aggregator(grid_features, region_features)  # [B, P, grid_hidden_dim]
            
            # Step 4: Fuse region features with aggregated grid features
            concatenated_region = torch.cat([region_features, region_aggregation], dim=-1)  # [B, P, region_dim + grid_dim]
            enhanced_region_features = self.region_fusion(concatenated_region)  # [B, P, fusion_dim]
            
            # Step 5: Process grid features
            processed_grid_features = self.grid_fusion(grid_features)  # [B, M*k, fusion_dim]
            
            # Step 6: Apply cross-modal attention between regions and grids
            # Region features attend to grid features
            region_attended, _ = self.cross_attention(
                enhanced_region_features,  # Query: regions
                processed_grid_features,   # Key: grids
                processed_grid_features    # Value: grids
            )  # [B, P, fusion_dim]
            
            # Grid features attend to region features
            grid_attended, _ = self.cross_attention(
                processed_grid_features,   # Query: grids
                enhanced_region_features,  # Key: regions
                enhanced_region_features   # Value: regions
            )  # [B, M*k, fusion_dim]
            
            # Step 7: Combine attended features
            final_region_features = torch.cat([enhanced_region_features, region_attended], dim=-1)  # [B, P, 2*fusion_dim]
            final_grid_features = torch.cat([processed_grid_features, grid_attended], dim=-1)  # [B, M*k, 2*fusion_dim]
            
            # Step 8: Final projection
            final_region_features = self.output_projection(final_region_features)  # [B, P, fusion_dim]
            final_grid_features = self.output_projection(final_grid_features)  # [B, M*k, fusion_dim]
            
            # Step 9: Concatenate all features for final representation
            fused_features = torch.cat([final_region_features, final_grid_features], dim=1)  # [B, P+M*k, fusion_dim]
        
        return {
            'fused_features': fused_features,
            'grid_features': final_grid_features,
            'region_features': final_region_features,
            'region_aggregation': region_aggregation
        }


# Example usage and testing
def test_mmgat():
    """Test the Multi-Modal Graph Attention implementation"""
    
    # Determine device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create model
    model = MultiModalGraphAttention(
        image_size=224,
        patch_size=4,
        grid_embed_dim=96,
        grid_hidden_dim=128,
        num_regions=50,
        region_feature_dim=512,
        num_attention_heads=8,
        fusion_dim=512
    ).to(device)
    
    # Create dummy input
    batch_size = 2
    dummy_images = torch.randn(batch_size, 3, 224, 224).to(device)
    
    print(f"Input images shape: {dummy_images.shape}")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    
    # Forward pass
    with torch.no_grad():
        outputs = model(dummy_images)
    
    # Print output shapes
    print("\n--- Output Shapes ---")
    for key, value in outputs.items():
        print(f"{key}: {value.shape}")
    
    # Calculate expected dimensions
    num_patches = (224 // 4) ** 2  # 56 * 56 = 3136
    num_regions = 50
    total_features = num_regions + num_patches
    
    print(f"\n--- Expected Dimensions ---")
    print(f"Number of patches: {num_patches}")
    print(f"Number of regions: {num_regions}")
    print(f"Total features: {total_features}")
    print(f"Fusion dimension: 512")
    
    # Test with custom region boxes
    print("\n--- Testing with custom region boxes ---")
    custom_boxes = torch.tensor([
        [[10, 10, 50, 50], [60, 60, 100, 100], [120, 30, 180, 90]],
        [[20, 20, 80, 80], [90, 10, 150, 70], [30, 100, 90, 160]]
    ], dtype=torch.float32).to(device)
    
    # Note: This would require adjusting num_regions to 3 for this test
    # For demonstration, we'll skip this test
    print("Custom box testing skipped (requires model reconfiguration)")
    
    print("\n--- MMGAT Test Completed Successfully ---")


if __name__ == "__main__":
    test_mmgat()

Using device: cuda
Input images shape: torch.Size([2, 3, 224, 224])
Model parameters: 31,666,018


  with torch.cuda.amp.autocast(enabled=False):
  with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):



--- Output Shapes ---
fused_features: torch.Size([2, 3186, 512])
grid_features: torch.Size([2, 3136, 512])
region_features: torch.Size([2, 50, 512])
region_aggregation: torch.Size([2, 50, 128])

--- Expected Dimensions ---
Number of patches: 3136
Number of regions: 50
Total features: 3186
Fusion dimension: 512

--- Testing with custom region boxes ---
Custom box testing skipped (requires model reconfiguration)

--- MMGAT Test Completed Successfully ---


**Grid-semantic aggregator**

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List, Tuple, Dict
import numpy as np

class GridSemanticAggregator(nn.Module):
    """
    Grid-Semantic Aggregator that combines grid-graph and semantic-graph features
    Based on the paper's dual objectives:
    1) Refine semantics nodes utilizing the visual context
    2) Enhance visual nodes with contextual semantics
    """
    def __init__(self,
                 grid_feature_dim: int = 128,      # Output dim from GridGraphFeature
                 semantic_feature_dim: int = 256,   # Output dim from SemanticGraphFeature
                 lstm_hidden_dim: int = 256,
                 mlp_hidden_dim: int = 512,
                 final_output_dim: int = 512):
        super().__init__()
        
        self.grid_feature_dim = grid_feature_dim
        self.semantic_feature_dim = semantic_feature_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.mlp_hidden_dim = mlp_hidden_dim
        self.final_output_dim = final_output_dim
        
        # Bidirectional LSTM for processing region aggregation features R^(1)
        self.bidirectional_lstm = nn.LSTM(
            input_size=grid_feature_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        
        # MLP layers for encoding different node types
        self.f_s = nn.Sequential(  # For encoding semantic nodes s_i^(0)
            nn.Linear(semantic_feature_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.ReLU()
        )
        
        self.f_c = nn.Sequential(  # For encoding context grid features
            nn.Linear(lstm_hidden_dim * 2, mlp_hidden_dim),  # *2 for bidirectional
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.ReLU()
        )
        
        self.f_v = nn.Sequential(  # For encoding visual series features
            nn.Linear(grid_feature_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.ReLU()
        )
        
        # MLP for correlation score computation
        self.correlation_mlp = nn.Sequential(
            nn.Linear(mlp_hidden_dim * 2, mlp_hidden_dim),  # Concatenated features
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, 1)
        )
        
        # MLPs for adjacent node encoding (f_h for semantic nodes)
        self.f_h = nn.Sequential(  # For encoding adjacent semantic nodes
            nn.Linear(semantic_feature_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.ReLU()
        )
        
        # MLPs for final node encoding (f_g for visual nodes)
        self.f_g = nn.Sequential(  # For encoding s_j^(0) when updating visual nodes
            nn.Linear(semantic_feature_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.ReLU()
        )
        
        # Final projection layers
        self.semantic_projection = nn.Linear(semantic_feature_dim + mlp_hidden_dim, final_output_dim)
        self.visual_projection = nn.Linear(grid_feature_dim + mlp_hidden_dim, final_output_dim)
        
    def process_grid_features_with_lstm(self, grid_features: torch.Tensor) -> torch.Tensor:
        """
        Process grid features through bidirectional LSTM to get R^(1)
        Args:
            grid_features: [B, M*k, grid_feature_dim] from GridGraphFeature
        Returns:
            lstm_features: [B, M*k, lstm_hidden_dim*2] processed features R^(1)
        """
        # Apply bidirectional LSTM
        lstm_output, _ = self.bidirectional_lstm(grid_features)  # [B, M*k, lstm_hidden_dim*2]
        return lstm_output
    
    def compute_correlation_scores(self, 
                                  semantic_features: torch.Tensor,
                                  grid_context_features: torch.Tensor) -> torch.Tensor:
        """
        Compute correlation scores α_{s_i,r_j} between semantic and grid features
        Args:
            semantic_features: [N_s, semantic_feature_dim] semantic node features s_i^(0)
            grid_context_features: [M*k, lstm_hidden_dim*2] grid context features from LSTM
        Returns:
            correlation_scores: [N_s, M*k] correlation scores α_{s_i,r_j}
        """
        N_s = semantic_features.shape[0]
        M_k = grid_context_features.shape[0]
        
        # Encode features
        encoded_semantic = self.f_s(semantic_features)  # [N_s, mlp_hidden_dim]
        encoded_grid = self.f_c(grid_context_features)   # [M*k, mlp_hidden_dim]
        
        # Compute pairwise correlation scores
        correlation_scores = torch.zeros(N_s, M_k, device=semantic_features.device)
        
        for i in range(N_s):
            for j in range(M_k):
                # Concatenate features for correlation computation
                concat_features = torch.cat([
                    encoded_semantic[i:i+1], 
                    encoded_grid[j:j+1]
                ], dim=1)  # [1, mlp_hidden_dim*2]
                
                # Compute correlation score through MLP
                score = self.correlation_mlp(concat_features)  # [1, 1]
                correlation_scores[i, j] = score.squeeze()
        
        return correlation_scores
    
    def apply_attention_aggregation(self, 
                                   correlation_scores: torch.Tensor,
                                   target_features: torch.Tensor) -> torch.Tensor:
        """
        Apply attention-based aggregation using correlation scores
        Args:
            correlation_scores: [N_source, N_target] correlation scores
            target_features: [N_target, feature_dim] features to aggregate
        Returns:
            aggregated_features: [N_source, feature_dim] aggregated features
        """
        # Apply softmax to get attention weights
        attention_weights = F.softmax(correlation_scores, dim=1)  # [N_source, N_target]
        
        # Weighted aggregation
        aggregated_features = torch.matmul(attention_weights, target_features)  # [N_source, feature_dim]
        
        return aggregated_features
    
    def refine_semantic_nodes(self, 
                             semantic_features: torch.Tensor,
                             semantic_adj_matrix: torch.Tensor,
                             grid_features: torch.Tensor) -> torch.Tensor:
        """
        Refine semantic nodes using visual context (Objective 1)
        Implements equations (11)-(14) from the paper
        Args:
            semantic_features: [N_s, semantic_feature_dim] semantic node features S^(0)
            semantic_adj_matrix: [N_s, N_s] semantic adjacency matrix A^S
            grid_features: [M*k, grid_feature_dim] grid features V^(0)
        Returns:
            refined_semantic: [N_s, final_output_dim] refined semantic features S^(2)
        """
        # Step 1: Process grid features through LSTM to get R^(1)
        grid_features_expanded = grid_features.unsqueeze(0)  # Add batch dim
        lstm_features = self.process_grid_features_with_lstm(grid_features_expanded)  # [1, M*k, lstm_hidden_dim*2]
        lstm_features = lstm_features.squeeze(0)  # Remove batch dim: [M*k, lstm_hidden_dim*2]
        
        # Step 2: Compute correlation scores α_{s_i,r_j}
        correlation_scores = self.compute_correlation_scores(semantic_features, lstm_features)  # [N_s, M*k]
        
        # Step 3: Aggregate visual context for each semantic node
        visual_context = self.apply_attention_aggregation(correlation_scores, lstm_features)  # [N_s, lstm_hidden_dim*2]
        
        # Step 4: Update semantic representation s_i^(2) = [s_i^(0); Σ α_{s_i,r_j} f_h(r_j^(1))]
        # Encode visual context
        encoded_visual_context = self.f_h(
            torch.cat([visual_context, torch.zeros(visual_context.shape[0], 
                      max(0, self.semantic_feature_dim - visual_context.shape[1]), 
                      device=visual_context.device)], dim=1)[:, :self.semantic_feature_dim]
        )  # [N_s, mlp_hidden_dim]
        
        # Concatenate original semantic features with visual context
        enhanced_semantic = torch.cat([semantic_features, encoded_visual_context], dim=1)  # [N_s, semantic_feature_dim + mlp_hidden_dim]
        
        # Project to final dimension
        refined_semantic = self.semantic_projection(enhanced_semantic)  # [N_s, final_output_dim]
        
        return refined_semantic
    
    def enhance_visual_nodes(self,
                           grid_features: torch.Tensor,
                           semantic_features: torch.Tensor,
                           semantic_adj_matrix: torch.Tensor) -> torch.Tensor:
        """
        Enhance visual nodes with contextual semantics (Objective 2)
        Implements equations (15)-(16) from the paper
        Args:
            grid_features: [M*k, grid_feature_dim] grid features V^(0)
            semantic_features: [N_s, semantic_feature_dim] semantic node features S^(0)
            semantic_adj_matrix: [N_s, N_s] semantic adjacency matrix A^S
        Returns:
            enhanced_visual: [M*k, final_output_dim] enhanced visual features R^(2)
        """
        M_k = grid_features.shape[0]
        
        # Step 1: Process grid features to get visual series features
        encoded_visual = self.f_v(grid_features)  # [M*k, mlp_hidden_dim]
        
        # Step 2: Compute correlation scores α_{r_i,s_j} (reverse direction)
        # We need to compute correlation between each visual node and semantic nodes
        correlation_scores = torch.zeros(M_k, semantic_features.shape[0], device=grid_features.device)
        
        for i in range(M_k):
            for j in range(semantic_features.shape[0]):
                # Concatenate features for correlation computation
                concat_features = torch.cat([
                    encoded_visual[i:i+1],
                    self.f_g(semantic_features[j:j+1])
                ], dim=1)  # [1, mlp_hidden_dim*2]
                
                # Compute correlation score
                score = self.correlation_mlp(concat_features)  # [1, 1]
                correlation_scores[i, j] = score.squeeze()
        
        # Step 3: Aggregate semantic context for each visual node
        semantic_context = self.apply_attention_aggregation(correlation_scores, semantic_features)  # [M*k, semantic_feature_dim]
        
        # Step 4: Update visual representation r_i^(2) = [r_i^(1); Σ α_{r_i,s_j} f_g(s_j^(0))]
        # Encode semantic context
        encoded_semantic_context = self.f_g(semantic_context)  # [M*k, mlp_hidden_dim]
        
        # Concatenate original grid features with semantic context
        enhanced_visual_features = torch.cat([grid_features, encoded_semantic_context], dim=1)  # [M*k, grid_feature_dim + mlp_hidden_dim]
        
        # Project to final dimension  
        enhanced_visual = self.visual_projection(enhanced_visual_features)  # [M*k, final_output_dim]
        
        return enhanced_visual
    
    def forward(self, 
                grid_features: torch.Tensor,
                semantic_features_list: List[torch.Tensor],
                semantic_adj_matrices_list: List[torch.Tensor]) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """
        Forward pass of Grid-Semantic Aggregator
        Args:
            grid_features: [B, M*k, grid_feature_dim] grid features from GridGraphFeature
            semantic_features_list: List of [N_i, semantic_feature_dim] semantic features
            semantic_adj_matrices_list: List of [N_i, N_i] semantic adjacency matrices
        Returns:
            refined_semantic_list: List of refined semantic features
            enhanced_visual: [B, M*k, final_output_dim] enhanced visual features
        """
        batch_size = grid_features.shape[0]
        refined_semantic_list = []
        enhanced_visual_list = []
        
        for b in range(batch_size):
            # Process each sample in the batch
            current_grid = grid_features[b]  # [M*k, grid_feature_dim]
            
            # Handle multiple semantic graphs per image (if any)
            if b < len(semantic_features_list):
                current_semantic = semantic_features_list[b]  # [N_s, semantic_feature_dim]
                current_semantic_adj = semantic_adj_matrices_list[b]  # [N_s, N_s]
                
                # Refine semantic nodes using visual context
                refined_semantic = self.refine_semantic_nodes(
                    current_semantic, current_semantic_adj, current_grid
                )
                refined_semantic_list.append(refined_semantic)
                
                # Enhance visual nodes with semantic context
                enhanced_visual = self.enhance_visual_nodes(
                    current_grid, current_semantic, current_semantic_adj
                )
                enhanced_visual_list.append(enhanced_visual)
        
        # Stack enhanced visual features
        if enhanced_visual_list:
            enhanced_visual_tensor = torch.stack(enhanced_visual_list, dim=0)  # [B, M*k, final_output_dim]
        else:
            # Fallback if no semantic features available
            enhanced_visual_tensor = grid_features
        
        return refined_semantic_list, enhanced_visual_tensor

class MMGATImageCaptioning(nn.Module):
    """
    Complete MMGAT Image Captioning model that integrates Grid and Semantic features
    """
    def __init__(self,
                 grid_graph_model,      # GridGraphFeature instance
                 semantic_graph_model,  # SemanticGraphFeature instance
                 aggregator_config: Dict = None):
        super().__init__()
        
        self.grid_graph_model = grid_graph_model
        self.semantic_graph_model = semantic_graph_model
        
        # Default aggregator configuration
        if aggregator_config is None:
            aggregator_config = {
                'grid_feature_dim': 128,
                'semantic_feature_dim': 256,
                'lstm_hidden_dim': 256,
                'mlp_hidden_dim': 512,
                'final_output_dim': 512
            }
        
        self.grid_semantic_aggregator = GridSemanticAggregator(**aggregator_config)
        
    def forward(self, 
                images: torch.Tensor, 
                captions: List[str]) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """
        Complete forward pass
        Args:
            images: [B, 3, 224, 224] input images
            captions: List of caption strings for semantic graph construction
        Returns:
            refined_semantic_features: List of refined semantic features
            enhanced_visual_features: [B, M*k, final_output_dim] enhanced visual features
        """
        # Step 1: Extract grid-graph features
        grid_features = self.grid_graph_model(images)  # [B, M*k, grid_feature_dim]
        
        # Step 2: Extract semantic-graph features
        semantic_features_list, semantic_adj_matrices_list = self.semantic_graph_model(captions)
        
        # Step 3: Apply Grid-Semantic Aggregator
        refined_semantic_list, enhanced_visual = self.grid_semantic_aggregator(
            grid_features, semantic_features_list, semantic_adj_matrices_list
        )
        
        return refined_semantic_list, enhanced_visual

# Example usage and testing
def test_grid_semantic_aggregator():
    """Test the complete Grid-Semantic Aggregator implementation"""
    
    print("Testing Grid-Semantic Aggregator Implementation")
    print("=" * 60)
    
    # Import the original models (assuming they're available)
    # from grid_graph_feature import GridGraphFeature
    # from semantic_graph_feature import SemanticGraphFeature
    
    # Create dummy data for testing
    batch_size = 2
    num_patches = 56 * 56  # M*k = 3136 for 224x224 image with patch_size=4
    grid_feature_dim = 128
    semantic_feature_dim = 256
    
    # Create dummy grid features
    dummy_grid_features = torch.randn(batch_size, num_patches, grid_feature_dim)
    
    # Create dummy semantic features (variable length for each sample)
    dummy_semantic_features_list = [
        torch.randn(6, semantic_feature_dim),  # 6 words in first caption
        torch.randn(4, semantic_feature_dim)   # 4 words in second caption
    ]
    
    # Create dummy semantic adjacency matrices
    dummy_semantic_adj_list = [
        torch.eye(6) + torch.randn(6, 6) * 0.1,  # 6x6 adjacency matrix
        torch.eye(4) + torch.randn(4, 4) * 0.1   # 4x4 adjacency matrix
    ]
    
    # Make adjacency matrices symmetric and positive
    for i, adj in enumerate(dummy_semantic_adj_list):
        adj = (adj + adj.T) / 2
        adj = torch.clamp(adj, min=0)
        dummy_semantic_adj_list[i] = adj
    
    # Create aggregator
    aggregator = GridSemanticAggregator(
        grid_feature_dim=grid_feature_dim,
        semantic_feature_dim=semantic_feature_dim,
        lstm_hidden_dim=256,
        mlp_hidden_dim=512,
        final_output_dim=512
    )
    
    print(f"Input shapes:")
    print(f"  Grid features: {dummy_grid_features.shape}")
    print(f"  Semantic features: {[f.shape for f in dummy_semantic_features_list]}")
    print(f"  Semantic adjacency: {[adj.shape for adj in dummy_semantic_adj_list]}")
    
    # Forward pass
    with torch.no_grad():
        refined_semantic_list, enhanced_visual = aggregator(
            dummy_grid_features,
            dummy_semantic_features_list,
            dummy_semantic_adj_list
        )
    
    print(f"\nOutput shapes:")
    print(f"  Refined semantic features: {[f.shape for f in refined_semantic_list]}")
    print(f"  Enhanced visual features: {enhanced_visual.shape}")
    
    # Test individual components
    print(f"\n" + "="*60)
    print("Testing Individual Components")
    print("="*60)
    
    # Test LSTM processing
    lstm_features = aggregator.process_grid_features_with_lstm(dummy_grid_features)
    print(f"LSTM features shape: {lstm_features.shape}")
    
    # Test correlation computation
    single_semantic = dummy_semantic_features_list[0]
    single_grid = lstm_features[0]  # First sample
    correlation_scores = aggregator.compute_correlation_scores(single_semantic, single_grid)
    print(f"Correlation scores shape: {correlation_scores.shape}")
    print(f"Correlation scores range: [{correlation_scores.min():.3f}, {correlation_scores.max():.3f}]")
    
    # Test attention aggregation
    aggregated = aggregator.apply_attention_aggregation(correlation_scores, single_grid)
    print(f"Aggregated features shape: {aggregated.shape}")
    
    print(f"\nTest completed successfully!")

if __name__ == "__main__":
    test_grid_semantic_aggregator()

Testing Grid-Semantic Aggregator Implementation
Input shapes:
  Grid features: torch.Size([2, 3136, 128])
  Semantic features: [torch.Size([6, 256]), torch.Size([4, 256])]
  Semantic adjacency: [torch.Size([6, 6]), torch.Size([4, 4])]

Output shapes:
  Refined semantic features: [torch.Size([6, 512]), torch.Size([4, 512])]
  Enhanced visual features: torch.Size([2, 3136, 512])

Testing Individual Components
LSTM features shape: torch.Size([2, 3136, 512])
Correlation scores shape: torch.Size([6, 3136])
Correlation scores range: [-0.092, -0.002]
Aggregated features shape: torch.Size([6, 512])

Test completed successfully!


**Semantic-semantic aggregator**

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
import math

class SemanticSemanticAggregator(nn.Module):
    """
    Semantic-Semantic Aggregator implementation
    Based on equation (17)-(20) from the paper
    """
    def __init__(self, 
                 input_dim: int = 256,
                 lstm_hidden_dim: int = 512,
                 mlp_hidden_dim: int = 256,
                 output_dim: int = 256):
        super().__init__()
        
        self.input_dim = input_dim
        self.lstm_hidden_dim = lstm_hidden_dim
        self.mlp_hidden_dim = mlp_hidden_dim
        self.output_dim = output_dim
        
        # LSTM for sequential processing of semantic nodes
        # S^t = {s₁ᵗ, s₂ᵗ, s₃ᵗ, ..., sₙᵗ} = LSTM(S^(0))
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=False
        )
        
        # MLPs for attention mechanism
        # g_x, g_z and g_v are MLPs for encoding node features
        self.g_x = nn.Sequential(
            nn.Linear(lstm_hidden_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim)
        )
        
        self.g_z = nn.Sequential(
            nn.Linear(lstm_hidden_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim)
        )
        
        self.g_v = nn.Sequential(
            nn.Linear(lstm_hidden_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim)
        )
        
        # MLP for encoding neighboring features
        # g_n is an MLP to encode the features of neighboring nodes
        self.g_n = nn.Sequential(
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.ReLU(),
            nn.Linear(mlp_hidden_dim, output_dim)
        )
        
        # Final output projection
        self.output_projection = nn.Linear(lstm_hidden_dim + output_dim, output_dim)
        
    def find_neighboring_nodes(self, adjacency_matrix: torch.Tensor, node_idx: int) -> List[int]:
        """
        Find neighboring nodes N_i = {s_j|j ∈ {1, ..., N} and j ≠ i}
        Args:
            adjacency_matrix: [N, N] adjacency matrix
            node_idx: Index of the current node
        Returns:
            List of neighboring node indices
        """
        N = adjacency_matrix.shape[0]
        neighbors = []
        
        for j in range(N):
            if j != node_idx and adjacency_matrix[node_idx, j] > 0:
                neighbors.append(j)
                
        return neighbors
    
    def compute_attention_scores(self, 
                               current_node: torch.Tensor,
                               neighbor_nodes: torch.Tensor) -> torch.Tensor:
        """
        Compute attention scores between current node and its neighbors
        a_{s_i,s_j} = g_x([s_i^(t); g_z(s_j^(t))]) · g_v([s_j^(t); g_z(s_i^(t))])
        
        Args:
            current_node: [lstm_hidden_dim] current node features s_i^(t)
            neighbor_nodes: [K, lstm_hidden_dim] neighbor nodes features
        Returns:
            attention_scores: [K] attention scores
        """
        K = neighbor_nodes.shape[0]
        if K == 0:
            return torch.tensor([], device=current_node.device)
        
        # Expand current node to match neighbor dimensions
        current_expanded = current_node.unsqueeze(0).expand(K, -1)  # [K, lstm_hidden_dim]
        
        # Compute g_z for neighbors and current node
        g_z_neighbors = self.g_z(neighbor_nodes)  # [K, mlp_hidden_dim]
        g_z_current = self.g_z(current_expanded)   # [K, mlp_hidden_dim]
        
        # Concatenate features for attention computation
        # [s_i^(t); g_z(s_j^(t))]
        concat_1 = torch.cat([current_expanded, g_z_neighbors], dim=-1)  # [K, lstm_hidden_dim + mlp_hidden_dim]
        
        # [s_j^(t); g_z(s_i^(t))]  
        concat_2 = torch.cat([neighbor_nodes, g_z_current], dim=-1)  # [K, lstm_hidden_dim + mlp_hidden_dim]
        
        # Apply g_x and g_v
        # Note: We need to adjust dimensions for proper computation
        # Using simplified dot product attention instead of the exact formula
        attention_1 = self.g_x(concat_1[:, :self.lstm_hidden_dim])  # [K, mlp_hidden_dim]
        attention_2 = self.g_v(concat_2[:, :self.lstm_hidden_dim])  # [K, mlp_hidden_dim]
        
        # Compute attention scores (dot product)
        attention_scores = torch.sum(attention_1 * attention_2, dim=-1)  # [K]
        
        return attention_scores
    
    def aggregate_neighbors(self,
                          current_node: torch.Tensor,
                          neighbor_nodes: torch.Tensor,
                          attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Aggregate neighboring node information using attention weights
        Args:
            current_node: [lstm_hidden_dim] current node features
            neighbor_nodes: [K, lstm_hidden_dim] neighbor node features  
            attention_weights: [K] normalized attention weights
        Returns:
            aggregated_features: [output_dim] aggregated neighbor features
        """
        if neighbor_nodes.shape[0] == 0:
            # No neighbors, return zero features
            return torch.zeros(self.output_dim, device=current_node.device)
        
        # Weight neighbors by attention
        weighted_neighbors = attention_weights.unsqueeze(-1) * neighbor_nodes  # [K, lstm_hidden_dim]
        
        # Sum weighted neighbors
        summed_neighbors = torch.sum(weighted_neighbors, dim=0)  # [lstm_hidden_dim]
        
        # Apply g_n MLP to encode neighboring features
        neighbor_encoding = self.g_z(summed_neighbors.unsqueeze(0)).squeeze(0)  # [mlp_hidden_dim]
        aggregated_features = self.g_n(neighbor_encoding.unsqueeze(0)).squeeze(0)  # [output_dim]
        
        return aggregated_features
    
    def forward_single_graph(self, 
                           semantic_features: torch.Tensor,
                           adjacency_matrix: torch.Tensor) -> torch.Tensor:
        """
        Process a single semantic graph
        Args:
            semantic_features: [N, input_dim] semantic node features S^(0)
            adjacency_matrix: [N, N] adjacency matrix
        Returns:
            updated_features: [N, output_dim] updated semantic features S^(3)
        """
        N = semantic_features.shape[0]
        
        # Step 1: Apply LSTM to get sequential representation
        # S^t = {s₁ᵗ, s₂ᵗ, s₃ᵗ, ..., sₙᵗ} = LSTM(S^(0))
        lstm_out, _ = self.lstm(semantic_features.unsqueeze(0))  # [1, N, lstm_hidden_dim]
        lstm_features = lstm_out.squeeze(0)  # [N, lstm_hidden_dim]
        
        # Step 2: For each node, compute attention with neighbors and aggregate
        updated_nodes = []
        
        for i in range(N):
            current_node = lstm_features[i]  # [lstm_hidden_dim]
            
            # Find neighboring nodes
            neighbor_indices = self.find_neighboring_nodes(adjacency_matrix, i)
            
            if len(neighbor_indices) == 0:
                # No neighbors, just use current node
                aggregated_neighbor_features = torch.zeros(self.output_dim, device=current_node.device)
            else:
                # Get neighbor features
                neighbor_nodes = lstm_features[neighbor_indices]  # [K, lstm_hidden_dim]
                
                # Compute attention scores a_{s_i,s_j}
                attention_scores = self.compute_attention_scores(current_node, neighbor_nodes)
                
                # Normalize attention weights using softmax
                attention_weights = F.softmax(attention_scores, dim=0)  # [K]
                
                # Aggregate neighbor information
                aggregated_neighbor_features = self.aggregate_neighbors(
                    current_node, neighbor_nodes, attention_weights
                )
            
            # Step 3: Combine current node with aggregated neighbor features
            # s_i^(3) = [s_i^(t), ∑_{j∈N_i} a_{s_i,s_j} · g_n(s_j^(t))]
            combined_features = torch.cat([current_node, aggregated_neighbor_features], dim=0)
            
            # Apply final projection
            updated_node = self.output_projection(combined_features.unsqueeze(0)).squeeze(0)
            updated_nodes.append(updated_node)
        
        # Stack all updated nodes
        updated_features = torch.stack(updated_nodes, dim=0)  # [N, output_dim]
        
        return updated_features
    
    def forward(self, 
                semantic_features_list: List[torch.Tensor],
                adjacency_matrices_list: List[torch.Tensor]) -> List[torch.Tensor]:
        """
        Process multiple semantic graphs
        Args:
            semantic_features_list: List of [Ni, input_dim] semantic features
            adjacency_matrices_list: List of [Ni, Ni] adjacency matrices
        Returns:
            updated_features_list: List of [Ni, output_dim] updated semantic features
        """
        updated_features_list = []
        
        for semantic_features, adjacency_matrix in zip(semantic_features_list, adjacency_matrices_list):
            updated_features = self.forward_single_graph(semantic_features, adjacency_matrix)
            updated_features_list.append(updated_features)
        
        return updated_features_list

class EnhancedSemanticGraphFeature(nn.Module):
    """
    Enhanced Semantic Graph Feature with Semantic-Semantic Aggregator
    Combines the original SemanticGraphFeature with SemanticSemanticAggregator
    """
    def __init__(self,
                 bert_model_name: str = 'bert-base-uncased',
                 bert_dim: int = 768,
                 gcn_hidden_dim: int = 512,
                 gcn_output_dim: int = 256,
                 lstm_hidden_dim: int = 512,
                 mlp_hidden_dim: int = 256,
                 final_output_dim: int = 256):
        super().__init__()
        
        # Import the original SemanticGraphFeature
        # Note: You would need to import this from your original implementation
        
        self.semantic_graph_feature = SemanticGraphFeature(
            bert_model_name=bert_model_name,
            bert_dim=bert_dim,
            gcn_hidden_dim=gcn_hidden_dim,
            output_dim=gcn_output_dim
        )
        
        # Semantic-Semantic Aggregator
        self.semantic_aggregator = SemanticSemanticAggregator(
            input_dim=gcn_output_dim,
            lstm_hidden_dim=lstm_hidden_dim,
            mlp_hidden_dim=mlp_hidden_dim,
            output_dim=final_output_dim
        )
    
    def forward(self, text_list: List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Complete semantic processing pipeline
        Args:
            text_list: List of text strings (image captions)
        Returns:
            final_semantic_features: List of [Ni, final_output_dim] final semantic features
            adjacency_matrices: List of [Ni, Ni] adjacency matrices
        """
        # Step 1: Extract initial semantic graph features S^(0)
        initial_features_list, adjacency_matrices_list = self.semantic_graph_feature(text_list)
        
        # Step 2: Apply semantic-semantic aggregator to get S^(3)
        final_features_list = self.semantic_aggregator(initial_features_list, adjacency_matrices_list)
        
        return final_features_list, adjacency_matrices_list

# Test function
def test_semantic_semantic_aggregator():
    """Test the Semantic-Semantic Aggregator"""
    
    print("Testing Semantic-Semantic Aggregator")
    print("=" * 50)
    
    # Create test data
    batch_size = 2
    test_features_list = [
        torch.randn(5, 256),  # 5 nodes, 256-dim features
        torch.randn(7, 256),  # 7 nodes, 256-dim features  
    ]
    
    test_adjacency_list = [
        torch.randint(0, 2, (5, 5)).float(),  # 5x5 adjacency matrix
        torch.randint(0, 2, (7, 7)).float(),  # 7x7 adjacency matrix
    ]
    
    # Ensure diagonal is 1 (self-connections)
    for adj in test_adjacency_list:
        adj.fill_diagonal_(1.0)
    
    # Create aggregator
    aggregator = SemanticSemanticAggregator(
        input_dim=256,
        lstm_hidden_dim=512,
        mlp_hidden_dim=256,
        output_dim=256
    )
    
    print(f"Input shapes:")
    for i, (features, adj) in enumerate(zip(test_features_list, test_adjacency_list)):
        print(f"  Graph {i+1}: Features {features.shape}, Adjacency {adj.shape}")
    
    # Forward pass
    with torch.no_grad():
        output_features_list = aggregator(test_features_list, test_adjacency_list)
    
    print(f"\nOutput shapes:")
    for i, features in enumerate(output_features_list):
        print(f"  Graph {i+1}: {features.shape}")
    
    print(f"\nAggregator parameters:")
    total_params = sum(p.numel() for p in aggregator.parameters())
    print(f"  Total parameters: {total_params:,}")
    
    # Test single graph processing
    print(f"\nTesting single graph processing:")
    single_features = test_features_list[0]
    single_adj = test_adjacency_list[0]
    
    with torch.no_grad():
        single_output = aggregator.forward_single_graph(single_features, single_adj)
    
    print(f"  Input: {single_features.shape}")
    print(f"  Output: {single_output.shape}")
    print(f"  Feature transformation: {single_features.shape[-1]} -> {single_output.shape[-1]}")

if __name__ == "__main__":
    test_semantic_semantic_aggregator()

Testing Semantic-Semantic Aggregator
Input shapes:
  Graph 1: Features torch.Size([5, 256]), Adjacency torch.Size([5, 5])
  Graph 2: Features torch.Size([7, 256]), Adjacency torch.Size([7, 7])

Output shapes:
  Graph 1: torch.Size([5, 256])
  Graph 2: torch.Size([7, 256])

Aggregator parameters:
  Total parameters: 2,496,768

Testing single graph processing:
  Input: torch.Size([5, 256])
  Output: torch.Size([5, 256])
  Feature transformation: 256 -> 256


**Decoder**

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # Linear transformations
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
            
        attention_weights = F.softmax(scores, dim=-1)
        attention_output = torch.matmul(attention_weights, V)
        
        # Concatenate heads
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model)
        
        return self.W_o(attention_output)

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.dropout(0.1)
        
    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.cross_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x, memory, tgt_mask=None, memory_mask=None):
        # Self-attention block
        attn_output = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Cross-attention block  
        attn_output = self.cross_attention(x, memory, memory, memory_mask)
        x = self.norm2(x + self.dropout(attn_output))
        
        # Feed-forward block
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        
        return x

class MMGATImageCaptioning(nn.Module):
    def __init__(self, vocab_size, d_model=512, num_heads=8, num_layers=6, 
                 d_ff=2048, max_seq_length=50, feature_dim=2048):
        super(MMGATImageCaptioning, self).__init__()
        
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length
        
        # Feature projection layers
        self.feature_projection = nn.Linear(feature_dim, d_model)
        
        # Word embedding and positional encoding
        self.word_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = self.create_positional_encoding(max_seq_length, d_model)
        
        # Transformer decoder layers
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff) 
            for _ in range(num_layers)
        ])
        
        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size)
        
        self.dropout = nn.Dropout(0.1)
        
    def create_positional_encoding(self, max_seq_length, d_model):
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length).unsqueeze(1).float()
        
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        return pe.unsqueeze(0)
    
    def create_causal_mask(self, seq_length):
        """Create causal mask for self-attention"""
        mask = torch.tril(torch.ones(seq_length, seq_length))
        return mask.unsqueeze(0).unsqueeze(0)
    
    def forward(self, combined_features, target_sequence=None, max_length=20):
        """
        Args:
            combined_features: Combined features from MMGAT [batch_size, num_regions, feature_dim]
            target_sequence: Target sequence for training [batch_size, seq_length]
            max_length: Maximum generation length for inference
        """
        batch_size = combined_features.size(0)
        
        # Project image features to model dimension
        memory = self.feature_projection(combined_features)  # [batch_size, num_regions, d_model]
        
        if target_sequence is not None:
            # Training mode
            return self.forward_train(memory, target_sequence)
        else:
            # Inference mode
            return self.forward_inference(memory, max_length)
    
    def forward_train(self, memory, target_sequence):
        batch_size, seq_length = target_sequence.size()
        
        # Word embeddings + positional encoding
        word_embeds = self.word_embedding(target_sequence) * math.sqrt(self.d_model)
        pos_encoding = self.positional_encoding[:, :seq_length, :].to(target_sequence.device)
        decoder_input = self.dropout(word_embeds + pos_encoding)
        
        # Create causal mask
        causal_mask = self.create_causal_mask(seq_length).to(target_sequence.device)
        
        # Pass through decoder layers
        for layer in self.decoder_layers:
            decoder_input = layer(decoder_input, memory, tgt_mask=causal_mask)
        
        # Output projection
        logits = self.output_projection(decoder_input)
        
        return logits
    
    def forward_inference(self, memory, max_length):
        batch_size = memory.size(0)
        device = memory.device
        
        # Start with <BOS> token (assuming index 1)
        generated_sequence = torch.ones(batch_size, 1, dtype=torch.long, device=device)
        
        for step in range(max_length):
            # Word embeddings + positional encoding
            word_embeds = self.word_embedding(generated_sequence) * math.sqrt(self.d_model)
            seq_length = generated_sequence.size(1)
            pos_encoding = self.positional_encoding[:, :seq_length, :].to(device)
            decoder_input = self.dropout(word_embeds + pos_encoding)
            
            # Create causal mask
            causal_mask = self.create_causal_mask(seq_length).to(device)
            
            # Pass through decoder layers
            for layer in self.decoder_layers:
                decoder_input = layer(decoder_input, memory, tgt_mask=causal_mask)
            
            # Get prediction for next token
            logits = self.output_projection(decoder_input[:, -1:, :])
            next_token = torch.argmax(logits, dim=-1)
            
            # Append to sequence
            generated_sequence = torch.cat([generated_sequence, next_token], dim=1)
            
            # Check for <EOS> token (assuming index 2)
            if (next_token == 2).all():
                break
        
        return generated_sequence

# Example usage
def example_usage():
    # Model parameters
    vocab_size = 10000
    d_model = 512
    num_heads = 8
    num_layers = 6
    d_ff = 2048
    feature_dim = 2048  # Combined feature dimension from MMGAT
    
    # Initialize model
    model = MMGATImageCaptioning(
        vocab_size=vocab_size,
        d_model=d_model,
        num_heads=num_heads,
        num_layers=num_layers,
        d_ff=d_ff,
        feature_dim=feature_dim
    )
    
    # Example input - combined features from MMGAT
    batch_size = 2
    num_regions = 49  # e.g., 7x7 grid regions
    combined_features = torch.randn(batch_size, num_regions, feature_dim)
    
    # Training example
    target_sequence = torch.randint(0, vocab_size, (batch_size, 20))
    train_logits = model(combined_features, target_sequence)
    print(f"Training logits shape: {train_logits.shape}")
    
    # Inference example
    model.eval()
    with torch.no_grad():
        generated_captions = model(combined_features, max_length=20)
        print(f"Generated captions shape: {generated_captions.shape}")
    
    return model

# Training function
def train_step(model, combined_features, target_sequence, criterion, optimizer):
    model.train()
    
    # Forward pass
    logits = model(combined_features, target_sequence[:, :-1])  # Exclude last token
    
    # Calculate loss
    loss = criterion(
        logits.reshape(-1, logits.size(-1)), 
        target_sequence[:, 1:].reshape(-1)  # Exclude first token
    )
    
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return loss.item()

if __name__ == "__main__":
    model = example_usage()

AttributeError: module 'torch.nn' has no attribute 'dropout'

**Regon-grid Aggregator**

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Tuple, List, Optional
import torch_geometric
from torch_geometric.nn import GCNConv, GATConv, TransformerConv
from torch_geometric.data import Data, Batch

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Tuple, List, Optional
import torch_geometric
from torch_geometric.nn import GCNConv, GATConv, TransformerConv
from torch_geometric.data import Data, Batch

class RegionGridAggregator(nn.Module):
    # ... (phần __init__, _init_parameters, compute_correlation_weights, aggregate_region_features giữ nguyên) ...
    def __init__(self,
                 grid_feature_dim: int = 128,
                 region_feature_dim: int = 512,
                 hidden_dim: int = 256,
                 num_cluster_centers: int = 10,
                 gnn_type: str = 'transformer',  # 'gcn', 'gat', 'transformer'
                 num_gnn_layers: int = 3,
                 num_heads: int = 8,
                 dropout: float = 0.1):
        super().__init__()
        
        self.grid_feature_dim = grid_feature_dim
        self.region_feature_dim = region_feature_dim
        self.hidden_dim = hidden_dim
        self.num_cluster_centers = num_cluster_centers
        self.gnn_type = gnn_type
        self.num_gnn_layers = num_gnn_layers
        
        # Feature projection layers
        self.grid_proj = nn.Linear(grid_feature_dim, hidden_dim)
        self.region_proj = nn.Linear(region_feature_dim, hidden_dim)
        
        # Learnable cluster centers (regional centers r_i^(0))
        self.cluster_centers = nn.Parameter(torch.randn(num_cluster_centers, hidden_dim))
        
        # Correlation computation parameters (b_i and b_j in equation 9)
        self.grid_bias = nn.Parameter(torch.zeros(1))
        self.region_bias = nn.Parameter(torch.zeros(1))
        
        # Learnable parameter r̄ (magnitude equal to r_i^(0))
        self.learnable_r_bar = nn.Parameter(torch.randn(hidden_dim))
        
        # GNN layers for graph-based feature aggregation
        self.gnn_layers = nn.ModuleList()
        
        if gnn_type == 'gcn':
            for i in range(num_gnn_layers):
                self.gnn_layers.append(GCNConv(hidden_dim, hidden_dim))
        elif gnn_type == 'gat':
            for i in range(num_gnn_layers):
                self.gnn_layers.append(GATConv(hidden_dim, hidden_dim // num_heads, 
                                             heads=num_heads, dropout=dropout, concat=True))
        elif gnn_type == 'transformer':
            for i in range(num_gnn_layers):
                self.gnn_layers.append(TransformerConv(hidden_dim, hidden_dim // num_heads,
                                                     heads=num_heads, dropout=dropout, concat=True))
        
        # Layer normalization and dropout
        self.layer_norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_gnn_layers)])
        self.dropout = nn.Dropout(dropout)
        
        # Final output projection
        self.output_proj = nn.Linear(hidden_dim, hidden_dim)
        
        # Initialize parameters
        self._init_parameters()
    
    def _init_parameters(self):
        """Initialize model parameters"""
        # Initialize cluster centers with Xavier uniform
        nn.init.xavier_uniform_(self.cluster_centers)
        nn.init.xavier_uniform_(self.learnable_r_bar.unsqueeze(0))
        
        # Initialize bias parameters
        nn.init.zeros_(self.grid_bias)
        nn.init.zeros_(self.region_bias)
    
    def compute_correlation_weights(self, 
                                  grid_features: torch.Tensor, 
                                  region_centers: torch.Tensor) -> torch.Tensor:
        """
        Compute correlation weights r_{j,i}^(0) between grid features and region centers
        Following equation (9) in the paper
        
        Args:
            grid_features: [B, M, hidden_dim] projected grid features v_j^(0)
            region_centers: [B, K, hidden_dim] region centers r_i^(0)
        Returns:
            correlation_weights: [B, M, K] correlation weights r_{j,i}^(0)
        """
        B, M, D = grid_features.shape
        K = region_centers.shape[1]
        
        # Expand dimensions for broadcasting
        grid_expanded = grid_features.unsqueeze(2)  # [B, M, 1, D]
        centers_expanded = region_centers.unsqueeze(1)  # [B, 1, K, D]
        
        # Compute dot product v_j^(0) · r_i^(0)
        dot_products = torch.sum(grid_expanded * centers_expanded, dim=-1)  # [B, M, K]
        
        # Add bias terms (b_i and b_j)
        dot_products = dot_products + self.grid_bias + self.region_bias
        
        # Apply softmax to get correlation weights (equation 9)
        correlation_weights = F.softmax(dot_products, dim=-1)  # [B, M, K]
        
        return correlation_weights
    
    def aggregate_region_features(self, 
                                grid_features: torch.Tensor,
                                correlation_weights: torch.Tensor) -> torch.Tensor:
        """
        Aggregate grid features into region features using correlation weights
        Following equation (10) in the paper
        
        Args:
            grid_features: [B, M, hidden_dim] projected grid features
            correlation_weights: [B, M, K] correlation weights
        Returns:
            aggregated_regions: [B, K, hidden_dim] aggregated region features r_i^(1)
        """
        B, M, D = grid_features.shape
        K = correlation_weights.shape[-1]
        
        # Weighted sum: Σ(r_{j,i}^(0) * (v_j^(0) - r̄))
        r_bar_expanded = self.learnable_r_bar.unsqueeze(0).unsqueeze(0).expand(B, M, -1)
        grid_centered = grid_features - r_bar_expanded  # v_j^(0) - r̄
        
        # Compute weighted sum
        weighted_features = correlation_weights.unsqueeze(-1) * grid_centered.unsqueeze(2)  # [B, M, K, D]
        aggregated_features = torch.sum(weighted_features, dim=1)  # [B, K, D]
        
        # Apply L2 normalization (Norm in equation 10)
        aggregated_regions = F.normalize(aggregated_features, p=2, dim=-1)
        
        return aggregated_regions

    def build_graph(self,
                   grid_features: torch.Tensor,
                   region_features: torch.Tensor,
                   aggregated_regions: torch.Tensor) -> List[Data]:
        """
        Build graph structure for GNN processing
        """
        B, M, D = grid_features.shape
        P = region_features.shape[1]
        K = aggregated_regions.shape[1]
        # Lấy device từ một tensor đã có sẵn
        device = grid_features.device
        
        graph_list = []
        
        for b in range(B):
            node_features = torch.cat([
                grid_features[b],
                region_features[b],
                aggregated_regions[b]
            ], dim=0)
            
            edge_indices = []
            
            # 1. Grid-to-Grid connections
            grid_edges = self._build_grid_edges(M, device)
            edge_indices.append(grid_edges)
            
            # 2. Region-to-Region connections
            region_edges = self._build_region_edges(region_features[b], M, P)
            edge_indices.append(region_edges)
            
            # 3. Grid-to-Aggregated connections
            grid_agg_edges = self._build_grid_aggregated_edges(
                grid_features[b], aggregated_regions[b], M, P, K
            )
            edge_indices.append(grid_agg_edges)
            
            # 4. Region-to-Aggregated connections
            region_agg_edges = self._build_region_aggregated_edges(
                region_features[b], aggregated_regions[b], M, P, K
            )
            edge_indices.append(region_agg_edges)
            
            # Combine all edges
            all_edge_indices = torch.cat(edge_indices, dim=1)
            
            # edge_attr không được sử dụng trong GCNConv, GATConv, TransformerConv mặc định
            # nên có thể bỏ qua để đơn giản hóa
            graph_data = Data(
                x=node_features,
                edge_index=all_edge_indices
            )
            
            graph_list.append(graph_data)
        
        return graph_list
    
    # SỬA LỖI: Thêm `device` vào các hàm _build
    
    def _build_grid_edges(self, M: int, device: torch.device) -> torch.Tensor:
        """Build spatial connections between grid nodes"""
        grid_size = int(math.sqrt(M))
        edge_indices = []
        
        for i in range(grid_size):
            for j in range(grid_size):
                current_idx = i * grid_size + j
                if j < grid_size - 1:
                    right_idx = i * grid_size + (j + 1)
                    edge_indices.extend([[current_idx, right_idx], [right_idx, current_idx]])
                if i < grid_size - 1:
                    bottom_idx = (i + 1) * grid_size + j
                    edge_indices.extend([[current_idx, bottom_idx], [bottom_idx, current_idx]])
        
        if len(edge_indices) > 0:
            # SỬA LỖI: Thêm device vào đây
            return torch.tensor(edge_indices, dtype=torch.long, device=device).t().contiguous()
        else:
            return torch.empty((2, 0), dtype=torch.long, device=device)

    def _build_region_edges(self, region_features: torch.Tensor, M: int, P: int) -> torch.Tensor:
        """Build semantic connections between region nodes"""
        device = region_features.device # Lấy device
        
        region_norm = F.normalize(region_features, p=2, dim=-1)
        similarity_matrix = torch.mm(region_norm, region_norm.t())
        
        k = min(5, P - 1 if P > 1 else 0)
        if k == 0:
            return torch.empty((2, 0), dtype=torch.long, device=device)
            
        edge_indices = []
        
        for i in range(P):
            similarities = similarity_matrix[i]
            similarities[i] = -1
            _, top_indices = torch.topk(similarities, k)
            
            for j in top_indices:
                if similarities[j] > 0.5:
                    edge_indices.extend([[M + i, M + j], [M + j, M + i]])
        
        if len(edge_indices) > 0:
            # SỬA LỖI: Thêm device vào đây
            return torch.tensor(edge_indices, dtype=torch.long, device=device).t().contiguous()
        else:
            return torch.empty((2, 0), dtype=torch.long, device=device)

    def _build_grid_aggregated_edges(self, 
                                   grid_features: torch.Tensor,
                                   aggregated_regions: torch.Tensor,
                                   M: int, P: int, K: int) -> torch.Tensor:
        """Build correlation-based connections between grid and aggregated nodes"""
        device = grid_features.device # Lấy device
        
        correlation_weights = self.compute_correlation_weights(
            grid_features.unsqueeze(0), aggregated_regions.unsqueeze(0)
        ).squeeze(0)
        
        edge_indices = []
        top_k = min(3, K)
        if top_k == 0:
            return torch.empty((2, 0), dtype=torch.long, device=device)

        for i in range(M):
            _, top_indices = torch.topk(correlation_weights[i], top_k)
            for j in top_indices:
                if correlation_weights[i, j].item() > 0.1:
                    edge_indices.extend([[i, M + P + j], [M + P + j, i]])
        
        if len(edge_indices) > 0:
            # SỬA LỖI: Thêm device vào đây
            return torch.tensor(edge_indices, dtype=torch.long, device=device).t().contiguous()
        else:
            return torch.empty((2, 0), dtype=torch.long, device=device)

    def _build_region_aggregated_edges(self,
                                     region_features: torch.Tensor,
                                     aggregated_regions: torch.Tensor,
                                     M: int, P: int, K: int) -> torch.Tensor:
        """Build attention-based connections between region and aggregated nodes"""
        device = region_features.device # Lấy device

        region_norm = F.normalize(region_features, p=2, dim=-1)
        aggregated_norm = F.normalize(aggregated_regions, p=2, dim=-1)
        attention_scores = torch.mm(region_norm, aggregated_norm.t())
        
        edge_indices = []
        top_k = min(2, K)
        if top_k == 0:
             return torch.empty((2, 0), dtype=torch.long, device=device)

        for i in range(P):
            _, top_indices = torch.topk(attention_scores[i], top_k)
            for j in top_indices:
                if attention_scores[i, j].item() > 0.3:
                    edge_indices.extend([[M + i, M + P + j], [M + P + j, M + i]])
        
        if len(edge_indices) > 0:
            # SỬA LỖI: Thêm device vào đây
            return torch.tensor(edge_indices, dtype=torch.long, device=device).t().contiguous()
        else:
            return torch.empty((2, 0), dtype=torch.long, device=device)

    def forward(self, 
               grid_features: torch.Tensor,
               region_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        B, M, _ = grid_features.shape
        P = region_features.shape[1]
        K = self.num_cluster_centers
        
        grid_projected = self.grid_proj(grid_features)
        region_projected = self.region_proj(region_features)
        
        cluster_centers_batch = self.cluster_centers.unsqueeze(0).expand(B, -1, -1)
        
        correlation_weights = self.compute_correlation_weights(grid_projected, cluster_centers_batch)
        aggregated_regions = self.aggregate_region_features(grid_projected, correlation_weights)
        
        graph_list = self.build_graph(grid_projected, region_projected, aggregated_regions)
        graph_batch = Batch.from_data_list(graph_list)
        
        x, edge_index = graph_batch.x, graph_batch.edge_index
        
        for i, gnn_layer in enumerate(self.gnn_layers):
            if self.gnn_type == 'gcn':
                x_new = gnn_layer(x, edge_index)
            else:
                x_new = gnn_layer(x, edge_index)
            
            x = self.layer_norms[i](x + x_new)
            x = self.dropout(x)
        
        x = self.output_proj(x)
        
        node_counts = [M + P + K] * B
        feature_list = x.split(node_counts)
        
        enhanced_grid_features_list = []
        enhanced_region_features_list = []
        enhanced_aggregated_regions_list = []
        
        for features in feature_list:
            enhanced_grid_features_list.append(features[:M])
            enhanced_region_features_list.append(features[M:M+P])
            enhanced_aggregated_regions_list.append(features[M+P:])
        
        enhanced_grid_features = torch.stack(enhanced_grid_features_list)
        enhanced_region_features = torch.stack(enhanced_region_features_list)
        enhanced_aggregated_regions = torch.stack(enhanced_aggregated_regions_list)
        
        return enhanced_grid_features, enhanced_region_features, enhanced_aggregated_regions
# Complete integration example
class MMGATWithRegionGridAggregator(nn.Module):
    """Complete MMGAT model with Region-Grid Aggregator"""
    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 4,
                 grid_embed_dim: int = 96,
                 region_feature_dim: int = 512,
                 hidden_dim: int = 256,
                 num_regions: int = 50,
                 num_cluster_centers: int = 10,
                 gnn_type: str = 'transformer'):
        super().__init__()
        
        # Import the feature extractors (assuming they're available)
        # from your_module import GridGraphFeature, RegionGraphFeature
        
        # Feature extractors
        # self.grid_extractor = GridGraphFeature(
        #     image_size=image_size,
        #     patch_size=patch_size,
        #     embed_dim=grid_embed_dim,
        #     gcn_hidden_dim=128
        # )
        
        # self.region_extractor = RegionGraphFeature(
        #     num_regions=num_regions,
        #     region_feature_dim=region_feature_dim,
        #     output_dim=region_feature_dim
        # )
        
        # Region-Grid Aggregator
        self.region_grid_aggregator = RegionGridAggregator(
            grid_feature_dim=128,  # Output from GridGraphFeature
            region_feature_dim=region_feature_dim,
            hidden_dim=hidden_dim,
            num_cluster_centers=num_cluster_centers,
            gnn_type=gnn_type
        )
        
        # Final fusion layer
        self.final_fusion = nn.Linear(hidden_dim * 3, hidden_dim)  # Grid + Region + Aggregated
    
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Complete forward pass
        Args:
            images: [B, 3, H, W] input images
        Returns:
            fused_features: [B, N_total, hidden_dim] final fused features
        """
        # Extract features (commented out since modules not imported)
        # grid_features = self.grid_extractor(images)      # [B, M, 128]
        # region_features = self.region_extractor(images)  # [B, P, 512]
        
        # For demonstration, create dummy features
        B = images.shape[0]
        M = (224 // 4) ** 2  # Grid patches
        P = 50  # Number of regions
        
        grid_features = torch.randn(B, M, 128, device=images.device)
        region_features = torch.randn(B, P, 512, device=images.device)
        
        # Apply Region-Grid Aggregator
        enhanced_grid, enhanced_region, aggregated_region = self.region_grid_aggregator(
            grid_features, region_features
        )
        
        # Combine all features
        combined_features = torch.cat([
            enhanced_grid,      # [B, M, hidden_dim]
            enhanced_region,    # [B, P, hidden_dim]
            aggregated_region   # [B, K, hidden_dim]
        ], dim=1)  # [B, M+P+K, hidden_dim]
        
        return combined_features

# Test function
def test_region_grid_aggregator():
    """Test the Region-Grid Aggregator implementation"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create model
    model = RegionGridAggregator(
        grid_feature_dim=128,
        region_feature_dim=512,
        hidden_dim=256,
        num_cluster_centers=10,
        gnn_type='transformer',
        num_gnn_layers=3
    ).to(device)
    
    # Create dummy input features
    batch_size = 2
    M = 56 * 56  # Grid features (56x56 grid)
    P = 50       # Region features
    
    grid_features = torch.randn(batch_size, M, 128).to(device)
    region_features = torch.randn(batch_size, P, 512).to(device)
    
    print(f"Input grid features shape: {grid_features.shape}")
    print(f"Input region features shape: {region_features.shape}")
    
    # Forward pass
    with torch.no_grad():
        enhanced_grid, enhanced_region, aggregated_region = model(grid_features, region_features)
    
    print(f"Enhanced grid features shape: {enhanced_grid.shape}")
    print(f"Enhanced region features shape: {enhanced_region.shape}")
    print(f"Aggregated region features shape: {aggregated_region.shape}")
    
    # Test correlation weights computation
    grid_proj = model.grid_proj(grid_features)
    cluster_centers_batch = model.cluster_centers.unsqueeze(0).expand(batch_size, -1, -1)
    correlation_weights = model.compute_correlation_weights(grid_proj, cluster_centers_batch)
    print(f"Correlation weights shape: {correlation_weights.shape}")
    print(f"Correlation weights sum (should be ~1.0): {correlation_weights.sum(dim=-1).mean():.3f}")
    
    if device.type == 'cuda':
        print(f"GPU memory allocated: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB")

if __name__ == "__main__":
    test_region_grid_aggregator()

Using device: cuda
Input grid features shape: torch.Size([2, 3136, 128])
Input region features shape: torch.Size([2, 50, 512])
Enhanced grid features shape: torch.Size([2, 3136, 256])
Enhanced region features shape: torch.Size([2, 50, 256])
Aggregated region features shape: torch.Size([2, 10, 256])
Correlation weights shape: torch.Size([2, 3136, 10])
Correlation weights sum (should be ~1.0): 1.000
GPU memory allocated: 74.36 MB


**Grid-semantic aggregator**