In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import cv2
import pickle
from pathlib import Path
from tqdm.notebook import tqdm
from PIL import Image
from hloc.utils.read_write_model import read_images_binary, read_points3D_binary, read_cameras_binary, qvec2rotmat
from segment_anything import sam_model_registry, SamPredictor

# Register tqdm's pandas integration, using the notebook progress-bar class
# imported above (tqdm.notebook).
# NOTE(review): no pandas object is used with tqdm's progress_* helpers in the
# visible cells; this call may be removable — confirm before deleting.
tqdm.pandas()

In [2]:
# Inputs: hloc/COLMAP SfM reconstruction of Aachen (binary model files).
SFM_DIR = 'outputs/aachen/sfm_superpoint+superglue'
IMAGES_PATH = f'{SFM_DIR}/images.bin'
POINTS3D_PATH = f'{SFM_DIR}/points3D.bin'
CAMERAS_PATH = f'{SFM_DIR}/cameras.bin'

# Segment Anything (SAM) checkpoint and the matching model-registry key.
SAM_CHECKPOINT = "segment-anything-main/sam_vit_h_4b8939.pth"
MODEL_TYPE = "vit_h"

# Directory that receives the pickled intermediate and final results.
OUTPUT_PATH = Path('DataBase')

# Number of patches per image side: each image is split into a
# PATCH_SIZE x PATCH_SIZE grid (a count, not a size in pixels).
PATCH_SIZE = 14
# Output width of the sin/cos embedding computed for each 7-D vector.
EMBEDDING_DIM = 192

In [3]:
def initialize_models_and_data(device=None):
    """Load the COLMAP reconstruction and build a SAM predictor.

    Args:
        device: torch device string for SAM (e.g. "cuda", "cpu"). When None,
            "cuda" is used if available and "cpu" otherwise, so the notebook
            does not crash on machines without a GPU.

    Returns:
        (images, points3Ds, cameras, predictor): the objects returned by hloc's
        read_images_binary / read_points3D_binary / read_cameras_binary for
        IMAGES_PATH / POINTS3D_PATH / CAMERAS_PATH, and a SamPredictor wrapping
        the MODEL_TYPE model loaded from SAM_CHECKPOINT.
    """
    # Load COLMAP data
    images = read_images_binary(IMAGES_PATH)
    points3Ds = read_points3D_binary(POINTS3D_PATH)
    cameras = read_cameras_binary(CAMERAS_PATH)

    # Initialize SAM on the requested (or best available) device.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
    sam.to(device=device)
    predictor = SamPredictor(sam)

    return images, points3Ds, cameras, predictor


images, points3Ds, cameras, predictor = initialize_models_and_data()

In [4]:
def extract_patch_2D_points(xys, img_width, img_height, size):
    """Bucket 2D keypoints into a ``size`` x ``size`` grid over the image.

    Args:
        xys: array-like of shape (N, 2+) whose first two columns are pixel
            x and y coordinates.
        img_width, img_height: image size in pixels.
        size: number of patches per side.

    Returns:
        dict mapping ``(row, col)`` -> list of points (each ``xys`` row as a
        list) falling in that patch. All ``size * size`` keys are present, in
        row-major (sorted) order; points keep their original relative order.

    Patches are ``img_width // size`` by ``img_height // size`` pixels. The
    leftover ``img_width % size`` columns / ``img_height % size`` rows are
    folded into the last patch of each row/column instead of being dropped.
    Points outside ``[0, img_width) x [0, img_height)`` are ignored.
    """
    patches_2D_points = {(row, col): [] for row in range(size) for col in range(size)}

    xys = np.asarray(xys)
    if xys.size == 0:
        return patches_2D_points

    # Guard against a zero-sized patch when the image is smaller than the grid.
    patch_width = max(img_width // size, 1)
    patch_height = max(img_height // size, 1)

    xs, ys = xys[:, 0], xys[:, 1]
    inside = (xs >= 0) & (xs < img_width) & (ys >= 0) & (ys < img_height)
    # Clamp to the last patch so the remainder strip is not lost.
    cols = np.minimum(xs[inside] // patch_width, size - 1).astype(int)
    rows = np.minimum(ys[inside] // patch_height, size - 1).astype(int)

    # Single O(N) pass instead of one full mask per patch.
    for row, col, point in zip(rows.tolist(), cols.tolist(), xys[inside].tolist()):
        patches_2D_points[(row, col)].append(point)

    return patches_2D_points

def convert_2D_to_3D(point_2d, K, R, t, depth=10.0):
    """Back-project a pixel to a world-space point at a fixed distance along its ray.

    The pixel ``(u, v)`` is lifted to the camera-space ray ``K^-1 [u, v, 1]^T``,
    normalised to unit length and scaled to ``depth``. It is then mapped to world
    coordinates by inverting the world-to-camera transform ``X_cam = R X + t``,
    i.e. ``X_world = R^-1 (X_cam - t)``.

    Args:
        point_2d: pixel coordinates ``(u, v)``; sub-pixel values are preserved.
        K: 3x3 intrinsic matrix.
        R: 3x3 rotation of the world-to-camera transform.
        t: translation of the world-to-camera transform (3 elements, any shape).
        depth: distance from the camera centre along the ray (default 10.0).

    Returns:
        np.ndarray of shape (3,): the world-space point.
    """
    # Keep floats: truncating to int would discard sub-pixel precision.
    pixel = np.array([float(point_2d[0]), float(point_2d[1]), 1.0])
    ray = np.linalg.solve(K, pixel)
    ray /= np.linalg.norm(ray)

    R = np.asarray(R, dtype=float)
    t = np.asarray(t, dtype=float).reshape(3)
    # Equivalent to applying inv([[R, t], [0, 1]]) to [ray * depth, 1].
    return np.linalg.solve(R, ray * depth - t)

In [5]:
def process_image(image_id, image, predictor,
                  image_dir='datasets/aachen/images/images_upright'):
    """Compute the pixel centre of every cell of the PATCH_SIZE x PATCH_SIZE grid.

    Args:
        image_id: image id (not used here; kept for call compatibility).
        image: COLMAP image record; only ``image.name`` is read.
        predictor: SAM predictor; ``set_image`` is called with the RGB image.
        image_dir: directory containing the image files.

    Returns:
        dict mapping ``(row, col)`` -> ``[center_x, center_y]`` in pixels, for
        every grid cell, in row-major order.

    Raises:
        OSError (e.g. FileNotFoundError): if PIL cannot open the image file.
        ValueError: if OpenCV cannot decode the image.
    """
    img_path = f'{image_dir}/{image.name}'
    # The image size comes from PIL, as before. The context manager closes the
    # file handle that Image.open otherwise keeps open.
    with Image.open(img_path) as img:
        img_width, img_height = img.size

    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        raise ValueError(f'OpenCV could not decode image: {img_path}')
    # NOTE(review): the returned patch centres do not use the SAM embedding
    # computed here; it only updates `predictor`'s internal state. It is kept
    # in case later cells rely on that state. Confirm this, since the call is
    # very expensive.
    predictor.set_image(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))

    patch_width = img_width // PATCH_SIZE
    patch_height = img_height // PATCH_SIZE
    half_width = img_width // (2 * PATCH_SIZE)
    half_height = img_height // (2 * PATCH_SIZE)
    # Row-major order matches the sorted (row, col) keys that later cells stack
    # into a tensor. The centres do not need the keypoints (image.xys), so they
    # are no longer bucketed per patch.
    return {
        (row, col): [col * patch_width + half_width, row * patch_height + half_height]
        for row in range(PATCH_SIZE)
        for col in range(PATCH_SIZE)
    }


TotalImageKey = {}
for image_id, image in tqdm(images.items(), desc="Processing images"):
    TotalImageKey[image_id] = process_image(image_id, image, predictor)

Processing images:   0%|          | 0/4328 [00:00<?, ?it/s]

In [6]:
def convert_patches_to_3D(image_id, patches, images, cameras):
    """Lift every patch centre of one image to a 7-D vector (3-D point + qvec).

    Args:
        image_id: key into ``images``.
        patches: dict mapping patch key -> ``[u, v]`` pixel centre.
        images, cameras: COLMAP image / camera mappings (hloc readers).

    Returns:
        dict with a ``'Token'`` entry ``[tvec, qvec]`` first, then one entry per
        patch key, ``[X, Y, Z, qvec]``. ``XYZ`` is the centre back-projected by
        ``convert_2D_to_3D``. A ``[0, 0]`` centre maps to the sentinel point
        ``(100000, 100000, 100000)``. All tensors are float64, so the sentinel
        branch no longer yields float32 tensors.
    """
    image = images[image_id]
    camera = cameras[image.camera_id]

    # Intrinsics; lens-distortion parameters are ignored. These COLMAP models
    # store (fx, fy, cx, cy, ...). The others (e.g. SIMPLE_PINHOLE,
    # SIMPLE_RADIAL) store (f, cx, cy, ...).
    focal_xy_models = {"PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV",
                       "FOV", "THIN_PRISM_FISHEYE"}
    params = camera.params
    if getattr(camera, "model", None) in focal_xy_models:
        fx, fy, cx, cy = params[0], params[1], params[2], params[3]
    else:
        fx = fy = params[0]
        cx, cy = params[1], params[2]
    K = np.array([
        [fx, 0, cx],
        [0, fy, cy],
        [0, 0, 1],
    ])
    R = qvec2rotmat(image.qvec)
    t = image.tvec
    qvec = np.asarray(image.qvec, dtype=np.float64)

    patches_3D = {'Token': torch.tensor(np.concatenate([t, qvec]), dtype=torch.float64)}

    sentinel = np.full(3, 100000.0)
    for key, point_2d in patches.items():
        if list(point_2d) == [0, 0]:
            point_3d = sentinel
        else:
            point_3d = convert_2D_to_3D(point_2d, K, R, t)
        patches_3D[key] = torch.tensor(np.concatenate([point_3d, qvec]), dtype=torch.float64)

    return patches_3D

In [7]:
# Lift each image's patch centres to 7-D vectors and cache the result on disk.
TotalImageKey3D = {}
for image_id in tqdm(TotalImageKey, desc="Converting to 3D"):
    TotalImageKey3D[image_id] = convert_patches_to_3D(
        image_id, TotalImageKey[image_id], images, cameras)

# Save intermediate results (the output directory may not exist yet).
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
with open(OUTPUT_PATH / f"TotalImageKey3D_{PATCH_SIZE}x{PATCH_SIZE}.pickle", "wb") as f:
    pickle.dump(TotalImageKey3D, f)

Converting to 3D:   0%|          | 0/4328 [00:00<?, ?it/s]

In [8]:
def rotary_position_embeddings_7d_to_192d(coords, embedding_dim=None):
    """Encode coordinate vectors as tiled sin/cos features.

    ``coords`` (shape ``(..., n)``; n = 7 in this notebook) is min-max
    normalised to [-1, 1]. The min and max are taken over the *entire* tensor,
    not per coordinate. The normalised values are scaled to angles in
    [-pi, pi] and turned into a block ``[sin(a), cos(a)]`` of width ``2n``.
    That block is tiled to fill ``embedding_dim`` features. Leftover features
    take the first ``embedding_dim % 2n`` entries of each row's own block.
    A constant input (max == min) maps every coordinate to -1 instead of
    producing NaN.

    Args:
        coords: tensor of shape ``(..., n)``.
        embedding_dim: output width; defaults to ``EMBEDDING_DIM``.

    Returns:
        Tensor of shape ``(..., embedding_dim)`` in torch's default dtype, on
        ``coords.device``.
    """
    if embedding_dim is None:
        embedding_dim = EMBEDDING_DIM

    # Normalize coordinates to [-1, 1]; avoid 0/0 for constant inputs.
    lo, hi = coords.min(), coords.max()
    span = hi - lo
    span = torch.where(span > 0, span, torch.ones_like(span))
    normalized = 2.0 * (coords - lo) / span - 1.0
    angles = normalized * torch.pi

    sin_cos = torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)
    block = sin_cos.shape[-1]

    embeddings = torch.zeros(*coords.shape[:-1], embedding_dim, device=coords.device)
    full_blocks, remaining_dims = divmod(embedding_dim, block)
    for i in range(full_blocks):
        embeddings[..., block * i:block * (i + 1)] = sin_cos

    # Slice per row: flattening here would broadcast row 0's values to every row.
    if remaining_dims > 0:
        embeddings[..., -remaining_dims:] = sin_cos[..., :remaining_dims]

    return embeddings

In [9]:
# Generate embeddings for all points. Each vector is encoded separately,
# because the min/max normalisation inside the embedding spans its whole
# input; batching vectors together would change the result. The
# comprehension keeps loop variables out of the global namespace (the old
# loop variable `id` shadowed the builtin).
encoded_points = {
    img_id: {
        key: rotary_position_embeddings_7d_to_192d(value.unsqueeze(0)).squeeze(0)
        for key, value in reps.items()
        if value is not None
    }
    for img_id, reps in tqdm(TotalImageKey3D.items(), desc="Generating embeddings")
}

Generating embeddings:   0%|          | 0/4328 [00:00<?, ?it/s]

In [10]:
# Stack per image: row 0 is the 'Token' (tvec + qvec) embedding, followed by
# the patch embeddings in the dict's insertion order.
tensor_dict = {}
for img_id, embeddings in tqdm(encoded_points.items(), desc="Creating final tensors"):
    patch_tensors = [emb for k, emb in embeddings.items() if k != 'Token']
    tensor_dict[img_id] = torch.stack([embeddings['Token'], *patch_tensors])

# Save results (the output directory may not exist yet).
OUTPUT_PATH.mkdir(parents=True, exist_ok=True)
with open(OUTPUT_PATH / f'Large_Patch_{PATCH_SIZE}x{PATCH_SIZE}_RT_RoPE_Tensor.pickle', 'wb') as f:
    pickle.dump(tensor_dict, f, pickle.HIGHEST_PROTOCOL)

Creating final tensors:   0%|          | 0/4328 [00:00<?, ?it/s]