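# Graph Neural Networks for Computer Vision

This notebook shows how to turn facial landmarks and image pixels into graphs, and defines a collection of small GCN-based models for node classification, hierarchical pooling, detection, segmentation, multimodal fusion, and navigation. The final section runs a forward-pass smoke test on every model using a small random graph.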

## 1. Initial Imports

In [1]:
import torch
import torch.nn as nn
import networkx as nx
import numpy as np
import warnings
warnings.filterwarnings('ignore')

from torch_geometric.nn import GCNConv, TopKPooling
from torch.nn import MultiheadAttention
from torch_geometric.data import Data, Batch
import matplotlib.pyplot as plt
from PIL import Image

## 2. Graph Construction Functions

In [47]:
def create_face_graph(landmarks, threshold=2.0):
    """Build an undirected proximity graph over facial landmarks.

    Node ``i`` stores ``landmarks[i]`` (unchanged) as its ``pos`` attribute.
    An edge joins nodes ``i`` and ``j`` when the Euclidean distance between
    their points is strictly less than ``threshold``.

    Args:
        landmarks: sequence of points, e.g. an (N, D) NumPy array or a list of
            coordinate tuples/lists. Non-array inputs are converted with
            ``np.asarray`` for the distance computation only.
        threshold: distance cutoff for connecting two landmarks.

    Returns:
        networkx.Graph with one node per landmark.
    """
    # Plain Python sequences (e.g. lists of tuples) do not support element-wise
    # subtraction, so distances are computed on an array view of the input.
    points = np.asarray(landmarks, dtype=float)

    G = nx.Graph()
    for i, landmark in enumerate(landmarks):
        G.add_node(i, pos=landmark)

    num_points = len(points)
    # Pairwise check over i < j so each undirected edge is considered once.
    for i in range(num_points):
        for j in range(i + 1, num_points):
            if np.linalg.norm(points[i] - points[j]) < threshold:
                G.add_edge(i, j)
    return G


def create_pixel_graph(image, connectivity=4):
    """Build a grid graph with one node per pixel of ``image``.

    Pixel ``(i, j)`` becomes node ``i * width + j`` with attributes
    ``features=image[i, j]`` and ``pos=(i, j)``. Each pixel is connected to its
    in-bounds neighbors: the 4 edge-adjacent pixels for ``connectivity=4``, or
    those plus the 4 diagonal pixels for ``connectivity=8``.

    The graph is built with Python loops over every pixel, so it has
    ``height * width`` nodes; downsample large images before calling this.

    Args:
        image: array indexable as ``image[i, j]`` whose first two dimensions
            are height and width.
        connectivity: neighborhood size, 4 or 8.

    Returns:
        networkx.Graph.

    Raises:
        ValueError: if ``connectivity`` is not 4 or 8.
    """
    # (row, col) offsets for each supported neighborhood; order matches the
    # original per-pixel neighbor lists.
    offsets_by_connectivity = {
        4: [(-1, 0), (1, 0), (0, -1), (0, 1)],
        8: [(-1, 0), (1, 0), (0, -1), (0, 1),
            (-1, -1), (-1, 1), (1, -1), (1, 1)],
    }
    # Validate before doing any work: previously an unsupported value left the
    # neighbor list unbound and failed with UnboundLocalError mid-loop.
    if connectivity not in offsets_by_connectivity:
        raise ValueError(f"connectivity must be 4 or 8, got {connectivity!r}")
    offsets = offsets_by_connectivity[connectivity]

    height, width = image.shape[:2]
    G = nx.Graph()

    for i in range(height):
        for j in range(width):
            node_id = i * width + j
            G.add_node(node_id, features=image[i, j], pos=(i, j))

            for di, dj in offsets:
                ni, nj = i + di, j + dj
                if 0 <= ni < height and 0 <= nj < width:
                    # nx.Graph ignores duplicate undirected edges, so adding
                    # (a, b) and later (b, a) is harmless.
                    G.add_edge(node_id, ni * width + nj)
    return G

## 3. Basic GNN Models

### 3.1 Simple GCN for Node-Level (Pixel/Region) Classification

In [48]:
class SimpleGCN(nn.Module):
    """Two-layer graph convolutional network producing per-node class scores.

    Layout: ``GCNConv(num_node_features -> 16)`` -> ReLU ->
    ``GCNConv(16 -> num_classes)``. No softmax is applied, so the outputs are
    unnormalized scores.

    Args:
        num_node_features: size of each input node feature vector.
        num_classes: number of output scores per node.
    """

    def __init__(self, num_node_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, x, edge_index):
        """Apply both graph convolutions and return the raw per-node scores."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)

### 3.2 Hierarchical GCN

In [49]:
class HierarchicalGCN(nn.Module):
    """Two rounds of GCN + TopK pooling followed by a final GCN layer.

    Each ``TopKPooling`` step uses ``ratio=0.8``, so the tensor returned by
    ``forward`` has fewer rows than the input ``x``.

    Args:
        num_node_features: size of each input node feature vector.
        num_classes: number of output channels per remaining node.
    """

    def __init__(self, num_node_features, num_classes):
        super(HierarchicalGCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 64)
        self.pool1 = TopKPooling(64, ratio=0.8)
        self.conv2 = GCNConv(64, 32)
        self.pool2 = TopKPooling(32, ratio=0.8)
        self.conv3 = GCNConv(32, num_classes)

    def forward(self, x, edge_index, batch=None):
        """Run conv -> pool -> conv -> pool -> conv.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity for ``x``.
            batch: graph-membership vector for batched graphs; ``None`` treats
                the input as a single graph.

        Returns:
            Per-node outputs (``num_classes`` channels) for the nodes that
            survive both pooling steps.
        """
        # ReLU after each hidden convolution, matching the other models in this
        # notebook; without it the stacked GCN layers stay linear in features.
        x = torch.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x = torch.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        return self.conv3(x, edge_index)

## 4. Object Detection and Segmentation Models

### 4.1 Object Proposal and Instance Segmentation GNNs

In [50]:
class ObjectProposalGNN(nn.Module):
    """Scores every node with a single value using three stacked GCN layers.

    Channel progression: ``num_node_features -> 64 -> 32 -> 1``, with ReLU
    after the first two layers. The final score is left unnormalized.

    Args:
        num_node_features: size of each input node feature vector.
    """

    def __init__(self, num_node_features):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, 64)
        self.conv2 = GCNConv(64, 32)
        self.conv3 = GCNConv(32, 1)

    def forward(self, x, edge_index, batch):
        """Return one raw score per node.

        ``batch`` is accepted for interface compatibility but is not used.
        """
        h = x
        for hidden_conv in (self.conv1, self.conv2):
            h = hidden_conv(h, edge_index).relu()
        return self.conv3(h, edge_index)

class InstanceSegmentationGNN(nn.Module):
    """Predicts a per-node mask probability with three stacked GCN layers.

    Channel progression: ``num_features -> 64 -> 32 -> 1``, with ReLU after the
    first two layers and a sigmoid on the output, so each value lies in (0, 1).

    Args:
        num_features: size of each input node feature vector.
    """

    def __init__(self, num_features):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 32)
        self.conv3 = GCNConv(32, 1)

    def forward(self, x, edge_index, batch):
        """Return per-node mask probabilities.

        ``batch`` is accepted for interface compatibility but is not used.
        """
        h = x
        for hidden_conv in (self.conv1, self.conv2):
            h = hidden_conv(h, edge_index).relu()
        return self.conv3(h, edge_index).sigmoid()

## 5. Multimodal GNN Models

### 5.1 Visual-Textual and Cross-Modal Retrieval GNNs

In [51]:
class VisualTextualGNN(nn.Module):
    """Fuses image and text node features on a shared graph and scores each node.

    Image and text features are each encoded by their own ``GCNConv`` (followed
    by ReLU). The results are summed element-wise, passed through a fusion
    ``GCNConv`` (followed by ReLU), and projected to one value per node by
    ``nn.Linear(hidden_dim, 1)``.

    Args:
        image_feature_dim: size of each image node feature vector.
        word_embedding_dim: size of each text node feature vector.
        hidden_dim: width of both encoders and the fusion layer.
    """

    def __init__(self, image_feature_dim, word_embedding_dim, hidden_dim):
        super(VisualTextualGNN, self).__init__()
        self.image_encoder = GCNConv(image_feature_dim, hidden_dim)
        self.text_encoder = GCNConv(word_embedding_dim, hidden_dim)
        self.fusion_layer = GCNConv(hidden_dim, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, 1)

    def forward(self, image_features, word_embeddings, edge_index):
        """Encode, fuse, and score both modalities over ``edge_index``.

        Both modalities are convolved over the same ``edge_index`` and added
        element-wise, so ``image_features`` and ``word_embeddings`` must have the
        same number of rows (nodes).

        Returns:
            Tensor with one output value per node.
        """
        # Activations between layers: without them encoder -> fusion -> linear
        # head is a composition of affine maps and collapses to a single one.
        image_enc = torch.relu(self.image_encoder(image_features, edge_index))
        text_enc = torch.relu(self.text_encoder(word_embeddings, edge_index))
        fused = torch.relu(self.fusion_layer(image_enc + text_enc, edge_index))
        return self.output_layer(fused)

class CrossModalRetrievalGNN(nn.Module):
    """Produces joint image-text node embeddings over a shared graph.

    Image and text features are each encoded by their own ``GCNConv`` (followed
    by ReLU). The results are summed element-wise and passed through a fusion
    ``GCNConv`` whose raw output is returned as the embedding.

    Args:
        image_dim: size of each image node feature vector.
        text_dim: size of each text node feature vector.
        hidden_dim: width of the encoders and of the returned embeddings.
    """

    def __init__(self, image_dim, text_dim, hidden_dim):
        super(CrossModalRetrievalGNN, self).__init__()
        self.image_encoder = GCNConv(image_dim, hidden_dim)
        self.text_encoder = GCNConv(text_dim, hidden_dim)
        self.fusion = GCNConv(hidden_dim, hidden_dim)

    def forward(self, image_features, text_features, edge_index):
        """Return fused ``hidden_dim``-channel embeddings, one row per node.

        Both inputs are convolved over the same ``edge_index`` and added
        element-wise, so they must have the same number of rows (nodes).
        """
        # ReLU after the encoders keeps encoder -> fusion from collapsing into a
        # single linear propagation; the final embedding is left unactivated.
        img_enc = torch.relu(self.image_encoder(image_features, edge_index))
        text_enc = torch.relu(self.text_encoder(text_features, edge_index))
        return self.fusion(img_enc + text_enc, edge_index)

## 6. Advanced Vision Models

### 6.1 Relational Object Detection and Panoptic Segmentation

In [52]:
class RelationalObjectDetectionGNN(nn.Module):
    """Shared two-layer GCN trunk feeding a per-node classifier and box regressor.

    Trunk: ``GCNConv(num_features -> 64)`` -> ReLU -> ``GCNConv(64 -> 32)`` -> ReLU.
    Heads: ``Linear(32 -> num_classes)`` for class scores and ``Linear(32 -> 4)``
    for a 4-value box refinement per node.

    Args:
        num_features: size of each input node feature vector.
        num_classes: number of class scores per node.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 32)
        self.classifier = nn.Linear(32, num_classes)
        self.bbox_regressor = nn.Linear(32, 4)

    def forward(self, x, edge_index):
        """Return ``(class_scores, bbox_refinement)`` from the shared node embeddings."""
        first = self.conv1(x, edge_index).relu()
        node_emb = self.conv2(first, edge_index).relu()
        return self.classifier(node_emb), self.bbox_regressor(node_emb)

class PanopticSegmentationGNN(nn.Module):
    """Shared two-layer GCN trunk with semantic and instance prediction heads.

    Trunk: ``GCNConv(num_features -> 64)`` -> ReLU -> ``GCNConv(64 -> 32)`` -> ReLU.
    Heads: ``Linear(32 -> num_classes)`` for semantic scores and
    ``Linear(32 -> 1)`` for a single instance value per node.

    Args:
        num_features: size of each input node feature vector.
        num_classes: number of semantic classes.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 32)
        self.classifier = nn.Linear(32, num_classes)
        self.instance_predictor = nn.Linear(32, 1)

    def forward(self, x, edge_index):
        """Return ``(semantic_pred, instance_pred)`` from the shared node embeddings."""
        first = self.conv1(x, edge_index).relu()
        node_emb = self.conv2(first, edge_index).relu()
        return self.classifier(node_emb), self.instance_predictor(node_emb)

## 7. Navigation and Hierarchical Models

In [53]:
class VisualLanguageNavigationGNN(nn.Module):
    """Predicts navigation action scores from a scene graph and an instruction graph.

    Each graph is encoded by its own ``GCNConv`` layer (followed by ReLU). The two
    embedding matrices are concatenated along the feature axis and mapped to
    ``num_actions`` scores by ``nn.Linear(hidden_dim * 2, num_actions)``.

    Args:
        visual_dim: size of each scene-graph node feature vector.
        instruction_dim: size of each instruction-graph node feature vector.
        hidden_dim: output width of each graph encoder.
        num_actions: number of action scores produced per row.
    """

    def __init__(self, visual_dim, instruction_dim, hidden_dim, num_actions=4):
        super(VisualLanguageNavigationGNN, self).__init__()
        self.visual_gnn = GCNConv(visual_dim, hidden_dim)
        self.instruction_gnn = GCNConv(instruction_dim, hidden_dim)
        self.navigation_head = nn.Linear(hidden_dim * 2, num_actions)

    def forward(self, visual_obs, instructions, scene_graph, instr_graph):
        """Encode both graphs and score the actions.

        Args:
            visual_obs: node features of the scene graph.
            instructions: node features of the instruction graph.
            scene_graph: edge index of the scene graph.
            instr_graph: edge index of the instruction graph.

        Returns:
            Unnormalized action scores, one row per concatenated node pair.

        Note:
            ``torch.cat(..., dim=-1)`` requires both embeddings to have the same
            number of rows, so the scene and instruction graphs must currently
            have the same node count. NOTE(review): pooling each graph to a
            single vector may be the intended design; confirm before changing
            the output shape.
        """
        # ReLU before the linear head; without it GCN -> Linear is purely affine.
        visual_feat = torch.relu(self.visual_gnn(visual_obs, scene_graph))
        instr_feat = torch.relu(self.instruction_gnn(instructions, instr_graph))
        combined = torch.cat([visual_feat, instr_feat], dim=-1)
        return self.navigation_head(combined)

class HierarchicalImageGNN(nn.Module):
    """Multi-level GCN encoder that returns node features from every level.

    Each level applies ``GCNConv`` -> ReLU -> ``TopKPooling(ratio=0.5)``, so
    each level keeps roughly half of the previous level's nodes.

    Args:
        input_dim: size of each input node feature vector.
        hidden_dims: output channels of each level; its length sets the number
            of levels. Any iterable of ints is accepted. The default is a tuple
            to avoid a shared mutable default argument.
    """

    def __init__(self, input_dim, hidden_dims=(64, 32, 16)):
        super(HierarchicalImageGNN, self).__init__()
        hidden_dims = list(hidden_dims)
        self.levels = len(hidden_dims)
        self.gnns = nn.ModuleList()
        self.pools = nn.ModuleList()

        curr_dim = input_dim
        for hidden_dim in hidden_dims:
            self.gnns.append(GCNConv(curr_dim, hidden_dim))
            self.pools.append(TopKPooling(hidden_dim, ratio=0.5))
            curr_dim = hidden_dim

    def forward(self, x, edge_index, batch=None):
        """Return a list with the pooled node features of each level, in order.

        Args:
            x: node feature matrix.
            edge_index: graph connectivity for ``x``.
            batch: graph-membership vector for batched graphs; ``None`` treats
                the input as a single graph.
        """
        features = []
        for gnn, pool in zip(self.gnns, self.pools):
            # ReLU between convolution and pooling, as in the other GCN models here.
            x = torch.relu(gnn(x, edge_index))
            x, edge_index, _, batch, _, _ = pool(x, edge_index, None, batch)
            features.append(x)
        return features

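## 8. Smoke Testing

Each model is run once on a small random graph. A model passes if its forward pass raises no exception; outputs are not evaluated.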

In [54]:
def test_models(seed=0):
    """Smoke-test every model in this notebook with one forward pass.

    A random graph (10 nodes, 3 features per node, 20 random edges) is drawn
    from a local ``torch.Generator`` seeded with ``seed``. The input graph is
    therefore the same on every run, and torch's global RNG is left untouched.
    Each model is instantiated (weights still come from the global RNG) and
    called with the argument pattern its ``forward`` expects. A model passes if
    the call raises no exception; outputs are not validated.

    Args:
        seed: seed for the generator that draws the input graph.

    Returns:
        dict mapping each model name to ``True`` if its forward pass ran, or
        ``False`` if it raised (the error message is printed).
    """
    num_nodes = 10
    num_features = 3
    num_classes = 2
    num_edges = 20

    gen = torch.Generator().manual_seed(seed)
    edge_index = torch.randint(0, num_nodes, (2, num_edges), generator=gen)
    x = torch.randn(num_nodes, num_features, generator=gen)
    batch = torch.zeros(num_nodes, dtype=torch.long)  # all nodes belong to graph 0

    # Argument tuples for the distinct forward() signatures used below.
    single = (x, edge_index)
    batched = (x, edge_index, batch)
    two_modalities = (x, x, edge_index)
    two_graphs = (x, x, edge_index, edge_index)

    # name -> (model, forward args); constructed in the same order as before.
    cases = {
        "SimpleGCN": (SimpleGCN(num_features, num_classes), single),
        "HierarchicalGCN": (HierarchicalGCN(num_features, num_classes), batched),
        "ObjectProposalGNN": (ObjectProposalGNN(num_features), batched),
        "InstanceSegmentationGNN": (InstanceSegmentationGNN(num_features), batched),
        "VisualTextualGNN": (VisualTextualGNN(num_features, num_features, 16), two_modalities),
        "RelationalObjectDetectionGNN": (RelationalObjectDetectionGNN(num_features, num_classes), single),
        "PanopticSegmentationGNN": (PanopticSegmentationGNN(num_features, num_classes), single),
        "CrossModalRetrievalGNN": (CrossModalRetrievalGNN(num_features, num_features, 16), two_modalities),
        "VisualLanguageNavigationGNN": (VisualLanguageNavigationGNN(num_features, num_features, 16), two_graphs),
        "HierarchicalImageGNN": (HierarchicalImageGNN(num_features), batched),
    }

    print("Testing models...")
    results = {}
    with torch.no_grad():  # forward passes only; no autograd graph needed
        for name, (model, args) in cases.items():
            try:
                model(*args)
            except Exception as e:  # deliberate: report and keep testing the rest
                print(f"{name}: Failed ✗ - {str(e)}")
                results[name] = False
            else:
                print(f"{name}: Success ✓")
                results[name] = True
    return results

In [58]:
# Test graph construction on random landmarks; a fixed seed keeps reruns reproducible.
rng = np.random.default_rng(0)
landmarks = rng.random((5, 2))
# Points lie in the unit square, so every pairwise distance is below sqrt(2);
# the default threshold of 2.0 would connect every pair (a complete graph).
face_graph = create_face_graph(landmarks, threshold=0.5)
print("\nFace graph created successfully ✓")

# Load and process sample image.
# The pixel graph has one node per pixel and is built with Python loops, so
# downsample first: a full-resolution photo yields millions of nodes.
SAMPLE_IMAGE_PATH = 'dataset/meeting.jpg'
PIXEL_GRAPH_SIZE = (10, 10)
sample_image = Image.open(SAMPLE_IMAGE_PATH)
image = np.array(sample_image.resize(PIXEL_GRAPH_SIZE))
pixel_graph = create_pixel_graph(image)
print("Pixel graph created successfully ✓")

# Test all GNN models (assigned so the notebook doesn't echo the return value)
model_results = test_models()


Face graph created successfully ✓
Pixel graph created successfully ✓
Testing models...
SimpleGCN: Success ✓
HierarchicalGCN: Success ✓
ObjectProposalGNN: Success ✓
InstanceSegmentationGNN: Success ✓
VisualTextualGNN: Success ✓
RelationalObjectDetectionGNN: Success ✓
PanopticSegmentationGNN: Success ✓
CrossModalRetrievalGNN: Success ✓
VisualLanguageNavigationGNN: Success ✓
HierarchicalImageGNN: Success ✓
