In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from torch_geometric.data import HeteroData
import networkx as nx
from collections import defaultdict
import pypose as pp

In [2]:
all_matches = torch.load('matches.pt')

def covisibility_graph(cams, matches, G=None):
    if G is None:
        G = nx.Graph()

    for cam_idx in cams:
        for i, j in matches[cam_idx]:
            G.add_edge((cam_idx, i), (cam_idx + 1, j))

    TC = nx.transitive_closure(G, reflexive=False)
    components = nx.connected_components(TC)
    points3D = [sublist for sublist in components if len(sublist) >= 3]
    
    cam2point = defaultdict(list)
    for i, point in enumerate(points3D):
            for cam, feature_idx in point:
                cam2point[cam].append((feature_idx, i))

    return points3D, cam2point, TC

covisibility_graph(list(range(9)), all_matches)

([{(0, 216),
   (1, 158),
   (2, 93),
   (3, 133),
   (4, 135),
   (5, 29),
   (6, 12),
   (7, 32),
   (8, 74),
   (9, 147)},
  {(0, 218), (1, 166), (2, 179), (3, 231), (4, 206)},
  {(0, 225), (1, 164), (2, 121), (3, 175)},
  {(0, 230), (1, 257), (2, 63)},
  {(0, 234), (1, 271), (2, 115)},
  {(0, 236), (1, 290), (2, 212)},
  {(0, 237), (1, 171), (2, 104), (3, 125), (4, 122), (5, 26)},
  {(0, 239), (1, 232), (2, 7)},
  {(0, 241), (1, 189), (2, 222), (3, 261), (4, 239), (5, 86), (6, 40)},
  {(0, 248), (1, 273), (2, 88)},
  {(0, 250), (1, 206), (2, 2)},
  {(0, 251), (1, 207), (2, 3)},
  {(0, 253),
   (1, 197),
   (2, 118),
   (3, 140),
   (4, 153),
   (5, 41),
   (6, 26),
   (7, 40),
   (8, 116),
   (9, 201)},
  {(0, 258),
   (1, 199),
   (2, 240),
   (3, 289),
   (4, 271),
   (5, 108),
   (6, 57),
   (7, 62),
   (8, 119),
   (9, 167)},
  {(0, 261), (1, 215), (2, 1)},
  {(0, 264), (1, 303), (2, 205)},
  {(0, 268),
   (1, 201),
   (2, 132),
   (3, 164),
   (4, 166),
   (5, 57),
   (6, 36),

In [3]:
def covisibility_graph(cams, matches, G=None):
    if G is None:
        G = nx.Graph()

    for cam_idx in cams:
        for i, j in matches[cam_idx]:
            G.add_edge((cam_idx, i), (cam_idx + 1, j))

    TC = nx.transitive_closure(G, reflexive=False)
    components = nx.connected_components(TC)
    points3D = [sublist for sublist in components if len(sublist) >= 2]
    
    cam2point = defaultdict(list)
    for i, point in enumerate(points3D):
            for cam, feature_idx in point:
                cam2point[cam].append((feature_idx, i))

    return points3D, cam2point, G

points3D, cam2point, G = covisibility_graph(list(range(9)), all_matches)

In [4]:
len(points3D)

2076

In [7]:
feature_matches = torch.load('features.pt')
all_matches = torch.load('matches.pt')

In [39]:
def get_connected_component_labels(seqential_matches):
    connected_components = torch.zeros((10,2000)) # Labels for points
    point_label = 0
    for i, match in enumerate(seqential_matches):
        for u, v in match:
            label1 = connected_components[i, u]
            label2 = connected_components[i+1, v]
            if label1 == label2: # Should check if they are both zeros
                point_label += 1
                connected_components[i, u] = point_label
                connected_components[i+1, v] = point_label
            else:   # Should use set the zero-valued label with the non-zero label.
                if label2 == 0: 
                    connected_components[i+1, v] = label1
                else:
                    connected_components[i, i] = label2
    return connected_components, point_label

connected_components, num_points = get_connected_component_labels(all_matches)

cams = torch.randn((10, 7))
points3D = torch.randn((num_points, 3))
non_zero = connected_components.nonzero()
camera_indices = non_zero[:,0]
point_indices = connected_components[non_zero[:,0], non_zero[:,1]]
edge_index = torch.vstack([camera_indices, point_indices]).int()
edge_attr = torch.vstack([feature_matches[cam]['keypoints'][0, point] for cam, point in non_zero])

cams.shape, points3D.shape, edge_index.shape, edge_attr.shape

(torch.Size([10, 7]),
 torch.Size([2076, 3]),
 torch.Size([2, 6674]),
 torch.Size([6674, 2]),
 tensor([   1,    2,    3,  ..., 2074, 2072, 2076], dtype=torch.int32))

In [36]:
import torch
from torch_geometric.data import HeteroData

# Initialize HeteroData object
graph = HeteroData()

# Initialize cameras and points with some example features (e.g., pose for cameras, coordinates for points)
graph['camera'].x = cams  # Example: 6D pose (position + orientation) for each camera
graph['point'].x = points3D   # 3D coordinates for each point

# Assuming we have some relationship between cameras and points, represented by edges
edge_index = edge_index  # [2, E] matrix of camera -> point
edge_attr = edge_attr  # 2D pixel coordinates for each camera-point pair

# Add edges to the graph (from 'camera' to 'point' with type 'sees')
graph[('camera', 'sees', 'point')].edge_index = edge_index
graph[('camera', 'sees', 'point')].edge_attr = edge_attr  # 2D points as edge attributes

# Function to retrieve 3D and 2D points seen by a specific camera
def get_camera_data(graph, camera_id):
    # Find edges originating from the given camera
    edges = (graph[('camera', 'sees', 'point')].edge_index[0] == camera_id).nonzero().view(-1)
    
    # Get indices of points seen by the camera
    point_indices = graph[('camera', 'sees', 'point')].edge_index[1][edges]
    
    # Retrieve 3D points and 2D pixel coordinates
    points_3d = graph['point'].x[point_indices]
    points_2d = graph[('camera', 'sees', 'point')].edge_attr[edges]
    
    return points_3d, points_2d

def get_point_data(graph, point_id):
    # Find edges that terminate at the given point
    edges = (graph[('camera', 'sees', 'point')].edge_index[1] == point_id).nonzero().view(-1)
    
    # Get indices of cameras that see the point
    camera_indices = graph[('camera', 'sees', 'point')].edge_index[0][edges]
    
    # Retrieve the poses of the cameras
    camera_poses = graph['camera'].x[camera_indices]
    
    # Retrieve the 2D pixel coordinates associated with the point in each camera
    pixel_coords = graph[('camera', 'sees', 'point')].edge_attr[edges]
    
    return camera_poses, pixel_coords

# Example usage: Get data for point 10
camera_poses, pixel_coords = get_point_data(graph, point_id=10)
print("Poses of cameras that see Point 10:", camera_poses)
print("Corresponding 2D pixel coordinates in those cameras:", pixel_coords)

# Example usage: Get data for camera 1
points_3d, points_2d = get_camera_data(graph, camera_id=1)
print("3D Points seen by Camera 1:", points_3d)
print("Corresponding 2D pixel coordinates:", points_2d)


Poses of cameras that see Point 10: tensor([[-0.3362, -1.2706, -0.9403,  0.1738, -0.9840, -0.6221, -0.9640],
        [-1.3231, -0.7026,  0.3232,  0.1284, -0.7455,  0.1614, -0.4462]])
Corresponding 2D pixel coordinates in those cameras: tensor([[471.4599,  24.6294],
        [504.2861,  18.5983]], device='cuda:0')
3D Points seen by Camera 1: tensor([[ 0.4775, -0.9425,  1.7001],
        [-1.4280,  2.1848,  1.3098],
        [-0.2998,  0.6578,  1.4361],
        ...,
        [-0.8882,  1.9069, -1.5784],
        [ 0.8375, -1.0456,  1.2880],
        [-1.1577,  0.1336,  1.5065]])
Corresponding 2D pixel coordinates: tensor([[169.3252,   5.1960],
        [138.5088,   5.8661],
        [141.8584,   6.5362],
        ...,
        [121.7607, 499.7419],
        [379.6807, 499.7419],
        [288.5713, 507.1132]], device='cuda:0')


In [41]:
class SceneGraph(nn.Module):
    def __init__(self, cam_model, matcher, max_feature_points=2000, device='cpu'):
        super(SceneGraph, self).__init__()
        self.num_cams = 0
        self.num_points = 0
        self.feature_points = dict()
        self.matches = defaultdict(dict)
        self.cam_model = cam_model.to(device)
        self.matcher = matcher

        self.graph = HeteroData(device=device)
        self.graph['camera'].x = pp.SE3(torch.empty(0, 7), device=device)
        self.graph['point'].x = torch.empty((0,3), device=device)
        self.connected_components = torch.empty((0, max_feature_points), device=device) # Does this need to be on gpu, i dont think so.

    def add_new_batch(self, init_poses, feature_points, depths):
        new_cams = self.add_feature_points(feature_points)
        sequential_cams = self.perform_sequential_matching(new_cams)
        self.update_connected_components(len(new_cams), sequential_cams)
        self.register_to_graph(new_cams)

    def add_feature_points(self, feature_points):
        start_cam = self.num_cams + 1
        end_cam = start_cam + len(feature_points)
        new_cams = list(range(start_cam, end_cam))

        self.num_cams += len(feature_points)

        for i, cam in enumerate(new_cams):
            self.feature_points[cam] = feature_points[i]['keypoints'][0]

        return new_cams

    def perform_sequential_matching(self, new_cams):
        sequential_cams = zip(new_cams[:-1], new_cams[1:])
        for cam1, cam2 in sequential_cams:
            _, _, matches01 = match_keypoints(self.feature_points[cam1], self.feature_points[cam2], self.matcher)
            self.matches[cam1][cam2] = matches01["matches"]
        return sequential_cams

    def update_connected_component_labels(self, num_new_cams, cam_pairs):
        self.connected_components = torch.cat([
            self.connected_components, 
            torch.zeros((len(num_new_cams), self.connected_components.shape[1]), device=self.connected_components.device)])

        for cam1, cam2 in cam_pairs:
            match = self.matches[cam1][cam2]
            for u, v in match:
                label1 = connected_components[cam1, u]
                label2 = connected_components[cam2, v]
                if label1 == label2: # Should check if they are both zeros
                    self.num_points += 1
                    self.connected_components[cam1, u] = self.num_points
                    self.connected_components[cam2, v] = self.num_points
                else:   # Should use set the zero-valued label with the non-zero label.
                    if label2 == 0: 
                        self.connected_components[cam2, v] = label1
                    else:
                        self.connected_components[cam1, u] = label2

    def register_to_graph(self, new_cams, init_poses, depths):
        non_zero = self.connected_components.nonzero()
        camera_indices = non_zero[:,0]
        point_indices = self.connected_components[non_zero[:,0], non_zero[:,1]]
        self.num_points = len(point_indices)

        num_new_points = self.num_points - self.graph['point'].x.shape[0]
        self.graph['camera'].x = torch.cat([self.graph['camera'].x, init_poses])
        self.graph['point'].x = torch.cat([self.graph['point'].x, torch.zeros((num_new_points,3), device=self.graph['point'].x.device)])
        self.triangulate_new_points(new_cams, depths)

        self.graph[('camera', 'sees', 'point')].edge_index = torch.vstack([camera_indices, point_indices]).int()
        self.graph[('camera', 'sees', 'point')].edge_attr = torch.vstack([self.feature_matches[cam][point] for cam, point in non_zero])

    def triangulate_new_points(self, new_cams, depths):
        for cam in new_cams:
            point_indices, keypoints = self.get_triangulation_data(cam)
            pose = self.poses[cam]
            depths = depths[cam, keypoints[:,1], keypoints[:, 0]]
            unit_vectors = F.normalize(self.camera_model.unproject_points(keypoints))
            cam_sees_3dpts = pose @ (unit_vectors * depths.unsqueeze(1))
            self.graph['point'].x[point_indices] = cam_sees_3dpts
    
    def get_triangulation_data(self, camera_id):
        """ Retrieves camera_ids and 2D points seen by 'camera_id'. """
        # Find edges originating from the given camera
        edges = (self.graph[('camera', 'sees', 'point')].edge_index[0] == camera_id).nonzero().view(-1)
        
        # Get indices of points seen by the camera
        point_indices = self.graph[('camera', 'sees', 'point')].edge_index[1][edges]
        
        # Retrieve 3D points and 2D pixel coordinates
        points_2d = self.graph[('camera', 'sees', 'point')].edge_attr[edges]
        
        return point_indices, points_2d
    
    def get_camera_data(self, camera_id):
        """ Retrieves 3D and 2D points seen by 'camera_id'. """
        # Find edges originating from the given camera
        edges = (self.graph[('camera', 'sees', 'point')].edge_index[0] == camera_id).nonzero().view(-1)
        
        # Get indices of points seen by the camera
        point_indices = self.graph[('camera', 'sees', 'point')].edge_index[1][edges]
        
        # Retrieve 3D points and 2D pixel coordinates
        points_3d = self.graph['point'].x[point_indices]
        points_2d = self.graph[('camera', 'sees', 'point')].edge_attr[edges]
        
        return point_indices, points_3d, points_2d

    def get_point_data(self, point_id):
        """ Retrieves the ids, poses, and pixel coordinates of all cameras which see 'point_id'. """
        # Find edges that terminate at the given point
        edges = (self.graph[('camera', 'sees', 'point')].edge_index[1] == point_id).nonzero().view(-1)
        
        # Get indices of cameras that see the point
        camera_indices = self.graph[('camera', 'sees', 'point')].edge_index[0][edges]
        
        # Retrieve the poses of the cameras
        camera_poses = self.graph['camera'].x[camera_indices]
        
        # Retrieve the 2D pixel coordinates associated with the point in each camera
        pixel_coords = self.graph[('camera', 'sees', 'point')].edge_attr[edges]
        
        return camera_indices, camera_poses, pixel_coords

    def forward(self, graph):
        # Calculate reprojection error, does not need to be implemented.
        pass
