In [1]:
import torch
from torch_geometric.data import HeteroData

# Build a toy camera/point heterogeneous graph to experiment with scene-graph queries.
torch.manual_seed(0)  # reproducible toy data on Restart & Run All
graph = HeteroData()

# Toy scene sizes
num_cameras = 1000
num_points = int(1e6)
num_edges = num_points + int(1e6)  # total number of camera -> point observations

# Random placeholder node features:
# - cameras: 7 values each (presumably 3D position + 4D quaternion orientation — adjust to your pose parametrisation)
# - points: 3D coordinates
graph["camera"].x = torch.randn(num_cameras, 7)
graph["point"].x = torch.randn(num_points, 3)

# Random camera -> point observations. edge_index has shape [2, num_edges]:
# row 0 indexes cameras in [0, num_cameras), row 1 indexes points in [0, num_points).
# Each row must be drawn from its own node set; drawing the point row from
# [0, num_cameras) would only ever connect cameras to the first num_cameras points.
src_cameras = torch.randint(0, num_cameras, (num_edges,), dtype=torch.long)
dst_points = torch.randint(0, num_points, (num_edges,), dtype=torch.long)
edge_index = torch.stack([src_cameras, dst_points], dim=0)
# Per-edge 2D pixel coordinates (the observation of the point in that camera), shape [num_edges, 2]
edge_attr = torch.rand(num_edges, 2)

# Attach the edges under the ('camera', 'sees', 'point') relation
graph[("camera", "sees", "point")].edge_index = edge_index
graph[("camera", "sees", "point")].edge_attr = edge_attr


# Function to retrieve 3D and 2D points seen by a specific camera
def get_camera_data(graph, camera_id):
    """Collect the 3D points observed by one camera together with their 2D observations.

    Args:
        graph: heterogeneous graph with ``"point"`` node features ``x`` and a
            ``("camera", "sees", "point")`` edge store holding ``edge_index``
            (row 0 = camera index, row 1 = point index) and per-edge ``edge_attr``.
        camera_id: index of the camera whose observations are wanted.

    Returns:
        Tuple ``(points_3d, points_2d)``: the rows of ``graph["point"].x`` hit by every
        edge leaving ``camera_id``, and the matching rows of ``edge_attr``, both in edge order.
    """
    sees = graph[("camera", "sees", "point")]
    # Boolean mask over edges whose source node is the requested camera
    from_camera = sees.edge_index[0] == camera_id
    seen_points = sees.edge_index[1][from_camera]
    return graph["point"].x[seen_points], sees.edge_attr[from_camera]


# Example usage: get the 3D points and their 2D pixel observations for one camera.
# A single variable drives both the query and the printed label so they cannot disagree.
camera_id = 1
points_3d, points_2d = get_camera_data(graph, camera_id=camera_id)
print(f"3D Points seen by Camera {camera_id}:", points_3d.shape)
print("Corresponding 2D pixel coordinates:", points_2d.shape)

3D Points seen by Camera 1: torch.Size([1995, 3])
Corresponding 2D pixel coordinates: torch.Size([1995, 2])


In [6]:
# Sanity check for the row-matching trick later used by SceneGraph.match_indices:
# copy three rows of B into A, then verify the matching A indices are recovered.
torch.manual_seed(0)
target_rows_A = [100, 200, 600]  # rows of A that receive copies
source_rows_B = [100, 499, 600]  # rows of B that are copied (ascending, so A indices come out ascending too)
for _ in range(10):
    A = torch.randn(1000, 2)
    B = torch.randn(1000, 2)

    A[target_rows_A, :] = B[source_rows_B]

    # row_eq[i, j] is True when B[i] == A[j] in every column; shape [len(B), len(A)]
    row_eq = (A.t() == B.unsqueeze(-1)).all(dim=1)
    values, indices = torch.topk(row_eq.int(), 1, 1)
    # Keep the A index only for rows of B that matched some row of A (ordered by B row)
    indices = indices[values != 0]
    assert indices.tolist() == target_rows_A, indices
    print(indices)

tensor([100, 200, 600])
tensor([100, 200, 600])
tensor([100, 200, 600])
tensor([100, 200, 600])
tensor([100, 200, 600])
tensor([100, 200, 600])
tensor([100, 200, 600])
tensor([100, 200, 600])
tensor([100, 200, 600])
tensor([100, 200, 600])


In [7]:
import torch
import torch.nn as nn
from torch_geometric.data import HeteroData


class SceneGraph(nn.Module):
    """Learnable scene graph holding camera poses and (eventually) registered 3D points.

    Camera poses are registered as an ``nn.Parameter`` on the module and also exposed
    as ``graph["camera"].x`` (the same tensor object), so optimizers built from
    ``self.parameters()`` update the poses stored in the graph.
    """

    def __init__(self, init_poses, depths, all_matches):
        """Build the graph from initial poses and pairwise matches.

        Args:
            init_poses: initial camera poses, one row per camera (tensor).
            depths: not used by this skeleton yet. TODO: use for 3D point registration.
            all_matches: matches consumed by ``populate_graph``; their structure is not
                defined in this notebook yet (TODO document).
        """
        super().__init__()
        # Assigning the Parameter as a module attribute registers it; a Parameter kept
        # only inside a HeteroData store is invisible to nn.Module.parameters().
        self.camera_poses = nn.Parameter(init_poses)
        self.graph = HeteroData()
        self.graph["camera"].x = self.camera_poses
        self.populate_graph(all_matches)

    def populate_graph(self, all_matches):
        """Iterate over ``all_matches`` and register the 3D points they imply.

        TODO: not implemented yet. Currently a no-op so that construction succeeds.
        """

    def register_3Dpoints(self, cam0, cam1):
        """Register 3D points for cameras ``cam0`` and ``cam1``.

        Some of these points may already exist in the graph.

        Raises:
            NotImplementedError: always, until this is implemented.
        """
        raise NotImplementedError("register_3Dpoints is not implemented yet")

    def match_indices(self, A, B):
        """Return indices into ``A`` of the rows that also appear in ``B``.

        For every row of ``B`` (in ``B`` order) that exactly equals some row of ``A``,
        the index of the first equal row of ``A`` is returned. Rows of ``B`` without a
        match are skipped, so ``A[result]`` equals the matched rows of ``B``.

        Args:
            A: 2D tensor of shape [N, D].
            B: 2D tensor of shape [M, D].

        Returns:
            1D long tensor of indices into ``A`` (empty if nothing matches or either input is empty).

        Note:
            Builds a dense [M, D, N] comparison, so memory grows as O(N * M * D).
        """
        if A.shape[0] == 0 or B.shape[0] == 0:
            # argmax/topk over an empty dimension raises, so short-circuit
            return torch.empty(0, dtype=torch.long, device=A.device)
        # row_eq[i, j] is True when row B[i] equals row A[j] in every column
        row_eq = (A.t() == B.unsqueeze(-1)).all(dim=1)
        has_match = row_eq.any(dim=1)
        # argmax returns the first maximal index, so duplicate rows in A resolve deterministically
        first_match = row_eq.int().argmax(dim=1)
        return first_match[has_match]

    def forward(self, graph):
        """Calculate the reprojection error. Does not need to be implemented."""
        pass