In [7]:
!pip install torch torchvision torch-geometric torchaudio scikit-image matplotlib



In [9]:
# Imports
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader  # Fixed import
from torch_geometric.nn import GATConv, global_max_pool  # Added global_max_pool
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

AttributeError: partially initialized module 'torch_geometric' has no attribute 'typing' (most likely due to a circular import)

In [None]:
import cv2
import numpy as np
import torch
from skimage.segmentation import slic
from scipy.spatial import Delaunay
from sklearn.neighbors import kneighbors_graph  # ADDED IMPORT
from torch_geometric.data import Data

def get_keypoints(skeleton):
    """Locate corner-like pixels of a skeleton image with the Harris detector.

    Args:
        skeleton: 2-D image array; it is cast to uint8 before detection.

    Returns:
        Integer array with one index pair per pixel whose Harris response
        is strictly greater than 1% of the strongest response.
    """
    response = cv2.cornerHarris(skeleton.astype(np.uint8), blockSize=2, ksize=3, k=0.04)
    # Relative cut-off: keep only responses above 1% of the peak value.
    cutoff = 0.01 * response.max()
    strong = response > cutoff
    return np.transpose(np.nonzero(strong))

def get_superpixels(skeleton, n_segments=25):
    """Segment a skeleton image into SLIC superpixels.

    Args:
        skeleton: image array to segment (a 2-D single-channel image in this notebook).
        n_segments: approximate number of superpixels requested from ``slic``.

    Returns:
        The label array produced by ``skimage.segmentation.slic``.
    """
    import inspect  # local import: only needed for the version probe below

    slic_kwargs = {'n_segments': n_segments, 'compactness': 10}
    # scikit-image >= 0.19 assumes the last axis holds channels and rejects 2-D input
    # unless channel_axis=None marks it as grayscale. Older releases do not know the
    # keyword (they detect 2-D grayscale input themselves), so only pass it when the
    # installed slic accepts it and the input really is 2-D.
    if np.ndim(skeleton) == 2 and 'channel_axis' in inspect.signature(slic).parameters:
        slic_kwargs['channel_axis'] = None
    return slic(skeleton, **slic_kwargs)

def _nodes_or_center(points, shape):
    """Return points as an (N, 2) array, or the single image-centre pixel when there are none."""
    points = np.asarray(points).reshape(-1, 2)
    if len(points) == 0:
        return np.array([[shape[0] // 2, shape[1] // 2]])
    return points


def _knn_edges(nodes, k):
    """Directed k-nearest-neighbour edges as a (2, E) int64 array, with k capped at len(nodes) - 1."""
    k = min(k, len(nodes) - 1)
    if k < 1:
        # Zero or one node: kneighbors_graph would reject n_neighbors < 1, so there are no edges.
        return np.empty((2, 0), dtype=np.int64)
    rows, cols = kneighbors_graph(nodes, n_neighbors=k, mode='connectivity').nonzero()
    return np.vstack([rows, cols]).astype(np.int64)


def _delaunay_edges(nodes):
    """Directed triangle edges of a Delaunay triangulation as (2, E), or None if it is degenerate."""
    try:
        simplices = Delaunay(nodes).simplices
    except RuntimeError:
        # scipy's QhullError derives from RuntimeError; raised e.g. for collinear points.
        return None
    if len(simplices) == 0:
        return None
    # Each triangle (a, b, c) contributes the directed edges a->b, b->c, c->a.
    return simplices[:, [0, 1, 1, 2, 2, 0]].reshape(-1, 2).T.astype(np.int64)


def construct_graph(img, method='keypoint'):
    """Build a PyTorch Geometric graph from the thinned skeleton of a grayscale image.

    The image is thresholded, thinned with ``cv2.ximgproc.thinning`` and turned into
    nodes and edges according to ``method``:

    * ``'keypoint'``   - Harris keypoints joined by Delaunay triangle edges; falls back to
      all skeleton pixels with 8-NN edges when fewer than 3 keypoints are found or the
      triangulation is degenerate.
    * ``'superpixel'`` - centroids of SLIC segments with at least 5 pixels, 4-NN edges.
    * ``'dense'``      - every skeleton pixel, 8-NN edges.
    * ``'stroke'``     - skeleton pixels sorted by column, chained in that order, plus 8-NN edges.

    Whenever a method yields no nodes, a single node at the image centre is used and the
    graph has no edges. The neighbour count is always capped at ``num_nodes - 1``.

    Args:
        img: torch tensor (or array-like) that squeezes to a 2-D image. Float images whose
            maximum is <= 1.0 are assumed to be in [0, 1] and are rescaled to [0, 255]
            for thresholding only.
        method: one of ``'keypoint'``, ``'superpixel'``, ``'dense'``, ``'stroke'``.

    Returns:
        ``Data`` with ``x`` of shape (N, 3) holding normalised (row, col) and the raw pixel
        intensity, ``edge_index`` of shape (2, E), and ``edge_attr`` of shape (E, 2) holding
        the distance and angle between the normalised endpoints.

    Raises:
        ValueError: if ``method`` is unknown or the squeezed image is not 2-D.
    """
    if method not in ('keypoint', 'superpixel', 'dense', 'stroke'):
        raise ValueError("Invalid method")

    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    img = np.asarray(img).squeeze()
    if img.ndim != 2:
        raise ValueError("Expected an image that squeezes to 2-D, got shape %s" % (img.shape,))
    height, width = img.shape

    # Threshold on a float32 copy so the raw values stay available as node intensities.
    gray = img.astype(np.float32)
    if np.issubdtype(img.dtype, np.floating) and gray.size and gray.max() <= 1.0:
        # Heuristic: ToTensor-style [0, 1] input would be wiped out by the 127 threshold.
        gray = gray * 255.0
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    skeleton = cv2.ximgproc.thinning(binary.astype(np.uint8))

    if method == 'keypoint':
        nodes = np.asarray(get_keypoints(skeleton)).reshape(-1, 2)
        edges = _delaunay_edges(nodes) if len(nodes) >= 3 else None
        if edges is None:
            # Too few / degenerate keypoints: use every skeleton pixel with k-NN edges.
            nodes = _nodes_or_center(np.argwhere(skeleton > 0), img.shape)
            edges = _knn_edges(nodes, 8)

    elif method == 'superpixel':
        segments = get_superpixels(skeleton)
        centroids = []
        for label in np.unique(segments):
            rows, cols = np.nonzero(segments == label)
            if rows.size < 5:
                continue  # skip tiny segments
            # (row, col) order, matching the other methods and the intensity lookup below.
            centroids.append((rows.mean(), cols.mean()))
        nodes = _nodes_or_center(centroids, img.shape)
        edges = _knn_edges(nodes, 4)

    elif method == 'dense':
        nodes = _nodes_or_center(np.argwhere(skeleton > 0), img.shape)
        edges = _knn_edges(nodes, 8)

    else:  # 'stroke'
        nodes = _nodes_or_center(np.argwhere(skeleton > 0), img.shape)
        nodes = nodes[np.argsort(nodes[:, 1])]  # sort left-to-right by column
        steps = np.arange(len(nodes) - 1)
        chain = np.vstack([steps, steps + 1])  # (2, N-1): node i -> node i+1
        edges = np.hstack([chain, _knn_edges(nodes, 8)])

    edge_index = torch.tensor(edges, dtype=torch.long)

    # Read intensities with the un-normalised (row, col) coordinates.
    intensities = img[nodes[:, 0].astype(int), nodes[:, 1].astype(int)].reshape(-1, 1)

    # Normalise coordinates by the largest valid index (27 for a 28x28 image).
    coords = nodes / np.array([max(height - 1, 1), max(width - 1, 1)], dtype=float)

    # Edge features: Euclidean length and direction of the normalised offset.
    offset = coords[edges[1]] - coords[edges[0]]
    dist = np.linalg.norm(offset, axis=1, keepdims=True)
    angle = np.arctan2(offset[:, 1], offset[:, 0]).reshape(-1, 1)
    edge_attr = torch.tensor(np.hstack([dist, angle]), dtype=torch.float)

    # Node features: normalised coordinates + original intensities.
    x = torch.tensor(np.hstack([coords, intensities]), dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)