In [2]:
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# Your existing function
def numpy_adj_to_torch_sparse_tensor(adj_matrix):
    """Convert a 2-D adjacency matrix into an (uncoalesced) sparse COO tensor.

    Originally written for adjacency matrices loaded from .mat files.

    Every non-zero entry ``adj_matrix[i, j]`` becomes a stored entry at
    ``(i, j)`` with value ``1.0``. Edge weights are NOT preserved, so the
    graph is binarised. The indices come from ``np.nonzero``, which yields
    them in row-major order. However, the returned tensor is not flagged as
    coalesced, so call ``.coalesce()`` before using ``.indices()`` /
    ``.values()``.

    Args:
        adj_matrix: 2-D array accepted by ``np.nonzero`` that exposes
            ``.shape`` (presumably a NumPy ndarray).

    Returns:
        A float32 ``torch.sparse_coo_tensor`` with shape ``adj_matrix.shape``.
    """
    nz_rows, nz_cols = np.nonzero(adj_matrix)
    # (2, nnz) coordinate matrix; torch sparse indices must be int64.
    edge_coords = torch.from_numpy(np.vstack([nz_rows, nz_cols]).astype(np.int64))
    ones = torch.ones(edge_coords.size(1), dtype=torch.float32)
    return torch.sparse_coo_tensor(
        edge_coords,
        ones,
        torch.Size(adj_matrix.shape),
        dtype=torch.float32,
    )

# FIXED: Function to extract edge_index from sparse tensor
def get_edge_index_from_sparse(sparse_tensor):
    """Return the ``(2, nnz)`` index matrix of a sparse COO tensor.

    PyTorch only allows ``Tensor.indices()`` on coalesced tensors, so the
    input is coalesced first. Coalescing sorts the indices and merges
    duplicates, which means the returned edge list is sorted by
    (source, target) and contains each pair once.

    Args:
        sparse_tensor: a ``torch.sparse_coo`` tensor, coalesced or not.

    Returns:
        The int64 index tensor of the coalesced copy (row 0 = sources,
        row 1 = targets), usable directly as a PyG-style ``edge_index``.
    """
    return sparse_tensor.coalesce().indices()

# Create a test adjacency matrix
def create_test_adj_matrix():
    """Build the fixed 5-node undirected demo graph as a dense adjacency matrix.

    The graph has 7 undirected edges. The result is a symmetric 5x5 integer
    array with a zero diagonal and 14 ones. A fresh array is returned on
    every call.
    """
    undirected_edges = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (2, 4), (3, 4)]
    adjacency = np.zeros((5, 5), dtype=int)
    for u, v in undirected_edges:
        # Mirror each edge so the matrix stays symmetric.
        adjacency[u, v] = adjacency[v, u] = 1
    return adjacency

# Test the conversion
def test_conversion():
    """Walk through adjacency matrix -> sparse tensor -> edge_index, printing each stage.

    This is a demo routine despite its ``test_`` prefix. It prints every
    intermediate result and returns the extracted ``edge_index``, but it
    asserts nothing.

    Returns:
        The ``(2, num_edges)`` index tensor extracted from the demo graph.
    """
    print("=== Testing Option 2: Extract edge_index from sparse tensor ===\n")

    # Stage 1: the dense input graph.
    dense_adj = create_test_adj_matrix()
    print("Original adjacency matrix:")
    print(dense_adj)
    print(f"Shape: {dense_adj.shape}")
    print()

    # Stage 2: dense -> sparse COO (not yet coalesced).
    coo = numpy_adj_to_torch_sparse_tensor(dense_adj)
    print("Sparse tensor (before coalescing):")
    print(f"Is coalesced: {coo.is_coalesced()}")
    print(coo)
    print()

    # Stage 3: sparse COO -> edge_index (the helper coalesces internally).
    edge_index = get_edge_index_from_sparse(coo)
    n_edges = edge_index.shape[1]
    print("Extracted edge_index:")
    print(edge_index)
    print(f"Shape: {edge_index.shape}")
    print(f"Number of edges: {n_edges}")
    print()

    # Stage 4: list each directed edge for visual inspection.
    print("Edge connections (source -> target):")
    sources, targets = edge_index.tolist()
    for src, tgt in zip(sources, targets):
        print(f"  {src} -> {tgt}")
    print()

    return edge_index

# Test with GCNConv
def test_with_gcnconv(edge_index, num_nodes=5, input_dim=3, output_dim=2):
    """Run one ``GCNConv`` forward pass over ``edge_index`` and print the result.

    A demo routine despite its ``test_`` prefix. It builds random node
    features, applies a freshly initialised ``GCNConv`` and reports success
    or failure on stdout.

    Args:
        edge_index: ``(2, num_edges)`` integer tensor of (source, target)
            node ids. Every id is presumably expected to be < ``num_nodes``.
        num_nodes: number of rows in the random feature matrix. Defaults to
            5, the size of the demo graph.
        input_dim: feature width fed into the layer.
        output_dim: feature width produced by the layer.

    Returns:
        The layer output on success. If the forward pass raised, the error
        is printed and ``None`` is returned instead of propagating.
    """
    print("=== Testing with GCNConv ===\n")

    # Unseeded random features, so the printed values differ between runs.
    x = torch.randn(num_nodes, input_dim)
    print("Node features:")
    print(x)
    print(f"Shape: {x.shape}")
    print()

    gcn_layer = GCNConv(input_dim, output_dim)

    # Guard only the forward pass. This is a demo boundary, so any failure
    # is reported and converted to a None return rather than re-raised.
    try:
        output = gcn_layer(x, edge_index)
    except Exception as e:
        print(f"❌ Error: {e}")
        return None

    print("GCNConv output:")
    print(output)
    print(f"Shape: {output.shape}")
    print("\n✅ Success! The edge_index works correctly with GCNConv")
    return output

# Alternative: You can also modify your original function to return coalesced tensor
def numpy_adj_to_torch_sparse_tensor_coalesced(adj_matrix):
    """Same conversion as ``numpy_adj_to_torch_sparse_tensor``, but coalesced.

    This delegates to ``numpy_adj_to_torch_sparse_tensor`` instead of
    duplicating its body, so the two converters cannot drift apart. The
    result is then coalesced (indices sorted, duplicates merged), so
    ``.indices()`` and ``.values()`` can be called on it directly.

    Args:
        adj_matrix: 2-D adjacency matrix. Every non-zero entry becomes an
            edge with value 1.0.

    Returns:
        A coalesced float32 sparse COO tensor with shape ``adj_matrix.shape``.
    """
    return numpy_adj_to_torch_sparse_tensor(adj_matrix).coalesce()

# Run the complete test
if __name__ == "__main__":
    # Test the conversion
    edge_index = test_conversion()

    # Test with GCNConv
    output = test_with_gcnconv(edge_index)

    # Additional verification: rebuild the dense adjacency from edge_index
    # and compare it with the original. The node count comes from the
    # original matrix rather than a hard-coded constant.
    print("\n=== Verification: Reconstruct adjacency matrix ===")
    original_adj = create_test_adj_matrix()
    num_nodes = original_adj.shape[0]
    reconstructed_adj = torch.zeros(num_nodes, num_nodes)
    # Place a 1 at every (source, target) pair in a single vectorized write.
    reconstructed_adj[edge_index[0], edge_index[1]] = 1

    print("Reconstructed adjacency matrix:")
    print(reconstructed_adj.numpy().astype(int))

    # array_equal compares values, so the int original vs float32 rebuild is fine.
    matches = np.array_equal(original_adj, reconstructed_adj.numpy())
    print(f"\nOriginal and reconstructed matrices match: {'✅' if matches else '❌'}")

    print("\n=== Testing alternative coalesced function ===")
    # Test the alternative function on the same input graph.
    sparse_tensor_coalesced = numpy_adj_to_torch_sparse_tensor_coalesced(original_adj)
    print(f"Is coalesced: {sparse_tensor_coalesced.is_coalesced()}")
    edge_index_alt = sparse_tensor_coalesced.indices()
    print("Edge index from coalesced tensor:")
    print(edge_index_alt)

=== Testing Option 2: Extract edge_index from sparse tensor ===

Original adjacency matrix:
[[0 1 1 0 0]
 [1 0 1 1 0]
 [1 1 0 1 1]
 [0 1 1 0 1]
 [0 0 1 1 0]]
Shape: (5, 5)

Sparse tensor (before coalescing):
Is coalesced: False
tensor(indices=tensor([[0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4],
                       [1, 2, 0, 2, 3, 0, 1, 3, 4, 1, 2, 4, 2, 3]]),
       values=tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
       size=(5, 5), nnz=14, layout=torch.sparse_coo)

Extracted edge_index:
tensor([[0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4],
        [1, 2, 0, 2, 3, 0, 1, 3, 4, 1, 2, 4, 2, 3]])
Shape: torch.Size([2, 14])
Number of edges: 14

Edge connections (source -> target):
  0 -> 1
  0 -> 2
  1 -> 0
  1 -> 2
  1 -> 3
  2 -> 0
  2 -> 1
  2 -> 3
  2 -> 4
  3 -> 1
  3 -> 2
  3 -> 4
  4 -> 2
  4 -> 3

=== Testing with GCNConv ===

Node features:
tensor([[-1.1015, -0.7075,  0.9741],
        [-0.1155, -0.8269,  0.0727],
        [-1.2478, -0.5151,  1.1010],
       