In [None]:
import csv
import math

from typing import List

import numpy as np
from matplotlib import pyplot as plt
# import the ZOD DevKit
from zod import ZodFrames
from zod import ZodSequences
from datetime import datetime, timezone

# import default constants
import zod.constants as constants
from zod.constants import Camera, Lidar, Anonymization, AnnotationProject

# import useful data classes
from zod.data_classes import LidarData

# Dataset locations. The defaults are the author's local paths; set the
# ZOD_DATASET_ROOT / MAPILLARY_METADATA environment variables to run elsewhere.
import os

# for loading zod data
zod_dataset = os.environ.get("ZOD_DATASET_ROOT", "/home/bjangley/VPR/ZOD/full")  # local path to zod
version = "mini"  # "mini" or "full"

# initialize ZodSequences and pick one specific sequence
zod_sequences = ZodSequences(dataset_root=zod_dataset, version=version)
zod_000002 = zod_sequences['000002']

# for mapillary data: a CSV whose `id`, `lat` and `long` columns are read by sort_images_by_distance
mapillary_metadata = os.environ.get(
    "MAPILLARY_METADATA", "/home/bjangley/VPR/mapillary_utils/zod_000002/metadata.csv"
)




def haversine(lat1, lon1, lat2, lon2, radius=6371000):
    """Great-circle distance between two points given in decimal degrees.

    Uses the haversine formula on a sphere of radius ``radius``.

    Args:
        lat1, lon1: latitude/longitude of the first point, in degrees.
        lat2, lon2: latitude/longitude of the second point, in degrees.
        radius: sphere radius; defaults to the Earth radius in metres (6371 km).

    Returns:
        The distance in the unit of ``radius`` (metres by default).
    """
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)

    a = math.sin(delta_phi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
    # Floating-point rounding can push sqrt(a) marginally above 1 for (near-)antipodal
    # points, and math.asin raises ValueError outside [-1, 1]; clamp to stay in domain.
    c = 2 * math.asin(min(1.0, math.sqrt(a)))

    return radius * c


def sort_images_by_distance(csv_file_path, target_lat, target_lng,
                            id_column='id', lat_column='lat', lng_column='long'):
    """Read image positions from a CSV file and sort them by distance to a target point.

    Args:
        csv_file_path: path to a CSV file with a header row that contains the
            id, latitude and longitude columns named below.
        target_lat, target_lng: reference position in decimal degrees.
        id_column, lat_column, lng_column: header names of the image id,
            latitude and longitude columns.

    Returns:
        A list of ``{'image_id': ..., 'distance': ...}`` dicts, nearest first;
        ``distance`` is the haversine distance in metres.
    """
    images = []
    # The csv module expects files opened with newline='' so that quoted fields
    # containing line breaks are parsed correctly.
    with open(csv_file_path, mode='r', newline='') as file:
        for row in csv.DictReader(file):
            lat = float(row[lat_column])
            lng = float(row[lng_column])
            images.append({
                'image_id': row[id_column],
                'distance': haversine(target_lat, target_lng, lat, lng),
            })

    return sorted(images, key=lambda image: image['distance'])


# Sort the Mapillary images by their distance to the reference position of the sequence.
# NOTE(review): the position is read from the sequence metadata (zod_000002.metadata);
# confirm it corresponds to the keyframe, as the variable names suggest.
keyframe_lat, keyframe_lon = zod_000002.metadata.latitude, zod_000002.metadata.longitude
sorted_images = sort_images_by_distance(
    mapillary_metadata,
    keyframe_lat,
    keyframe_lon,
)








Each ZOD camera frame is an instance of `zod.data_classes.sensor.CameraFrame`, which has the attributes `filepath`, `time`, `height` and `width`.

In [None]:
# For a specific sequence (here 000002) all camera frames are available as a list
# of zod CameraFrame instances.
zod_000002_camera_frames = zod_000002.info.get_camera_frames()
# Reuse the list fetched above rather than querying the sequence info again.
print(f"Number of camera frames: {len(zod_000002_camera_frames)}")

# The keyframe is exposed directly on the sequence info.
keyframe = zod_000002.info.get_key_camera_frame()
keyframe_image = keyframe.read()
plt.imshow(keyframe_image)
plt.show()

# Keyframe time as a POSIX timestamp
print(zod_000002.info.keyframe_time.timestamp())

# Start and end time of the sequence
print('Starttime: ', zod_000002.info.start_time, 'End-time: ', zod_000002.info.end_time)




In [None]:
# Index the frames for convenience. The list appears to be sorted already, but the sort order
# is unconfirmed; we assume the keyframe is always the 100th frame (index 99).

# sorted_frames = sorted(zod_000002_camera_frames, key=lambda frame: frame.time.timestamp())
# print(sorted_frames)

# keyframe_time = zod_000002.info.keyframe_time.timestamp()
# print((keyframe_time))

# # Step 3: Find the closest frame and its index
# closest_frame_index, closest_frame = min(
#     enumerate(sorted_frames),
#     key=lambda item: abs(item[1].time.timestamp() - keyframe_time)
# )

# print(f"Closest Frame Index: {closest_frame_index}")
# print(f"Closest Frame: {closest_frame}")

In [None]:
# To find the pose at a specific time, query the sequence's ego_motion with an
# array of POSIX timestamps.

np.set_printoptions(precision=2, suppress=True)

# Last camera frame of the sequence (used as the "end-time" reference below).
# NOTE(review): this is the last camera frame's time, which need not equal info.end_time.
camera_frame = zod_000002_camera_frames[-1]
camera_frame_time = camera_frame.time
print(camera_frame_time, type(camera_frame_time))

# Pose for a single timestamp
poses = zod_000002.ego_motion.get_poses(np.array([camera_frame_time.timestamp()]))
print("End-time Pose: \n", poses)

poses = zod_000002.ego_motion.get_poses(np.array([zod_000002.info.keyframe_time.timestamp()]))
print("KeyFrame Pose: \n", poses)

# Pose at the start time, i.e. the first stored pose. This should match
# zod_000002.ego_motion.get_poses(np.array([zod_000002.info.start_time.timestamp()]))
# -- assumes ego_motion.poses is time-ordered and begins at start_time; TODO confirm.
poses = zod_000002.ego_motion.poses[0]
print("Start-time Poses: \n", poses)

# Pose of the nth frame: take the frame's timestamp and look it up via ego_motion.
n = 99
nth_frame = zod_000002_camera_frames[n]
nth_frame_time = nth_frame.time.timestamp()
nth_frame_pose = zod_000002.ego_motion.get_poses(np.array([nth_frame_time]))
print(n, "th frame's pose \n", nth_frame_pose)

In [None]:
# Front-camera calibration: the first three columns of the intrinsics (the camera
# matrix), the extrinsics, and the distortion / undistortion parameters.
front_calibration = zod_000002.calibration.cameras[Camera.FRONT]
camera_matrix = front_calibration.intrinsics[:, 0:3]
extrinsics = front_calibration.extrinsics
distortion = front_calibration.distortion
undistortion = front_calibration.undistortion
print(camera_matrix, distortion, undistortion)
print(extrinsics)


In [None]:
#supressing unnecessary warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)



import matplotlib.pyplot as plt
import os
from PIL import Image


# Cap MKL / numexpr / OpenMP at 8 threads to stay within the CPU limits.
# NOTE(review): such libraries typically read these variables when they are first
# loaded; numpy was already imported in the first cell, so this may only affect
# libraries loaded afterwards (e.g. via `utils`). Consider moving it above all imports.
import os
for _thread_env_var in ("MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS", "OMP_NUM_THREADS"):
    os.environ[_thread_env_var] = "8"


from utils import (
    getMasterOutout,
    scale_intrinsics,
    CameraMatrix,
    run_pnp,
    getImageFromIndex,
)

keyframe = zod_000002.info.get_key_camera_frame()  # serves as the anchor
query = zod_000002_camera_frames[99]  # 100th camera frame of the sequence

keyframe_image = keyframe.read()
query_image = query.read()

# Anchor and query side by side in a 1x2 grid, without axes.
fig, axarr = plt.subplots(1, 2, figsize=(10, 5))
panels = [(keyframe_image, "Anchor Image"), (query_image, "Query Image")]
for panel_ax, (panel_image, panel_title) in zip(axarr, panels):
    panel_ax.imshow(panel_image)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.tight_layout()  # keep titles from overlapping
plt.show()

# Rescale the front-camera matrix from the query's native size to the image size used by MASt3R.
master_size = [512, 384]  # image size used by MASt3R, passed as (width, height) below
n_matches = 200
K_scaled = scale_intrinsics(camera_matrix, query.width, query.height, master_size[0], master_size[1])

In [None]:
# print(query.height, query.width)
# print(camera_matrix)

In [None]:
# Run MASt3R on (query, keyframe) without visualising the matches. The query is passed
# first, so pts3d_im0 is presumably the query's point map (used for the focal estimate below).
(
    filtered_matches_im0, filtered_matches_im1,
    matches_im0, matches_im1,
    pts3d_im0, pts3d_im1,
    conf_im0, conf_im1,
    desc_conf_im0, desc_conf_im1,
) = getMasterOutout(query.filepath, keyframe.filepath, n_matches, visualizeMatches=False)

import numpy as np

def xy_grid(W, H, origin=(0, 0), homogeneous=False):
    """Create an (H, W, 2) array of pixel coordinates.

    Entry ``[row, col]`` holds ``(origin[0] + col, origin[1] + row)``, i.e. (x, y).
    If ``homogeneous`` is True a third channel of ones is appended, giving (H, W, 3).
    """
    x = np.arange(origin[0], origin[0] + W)
    y = np.arange(origin[1], origin[1] + H)
    grid_x, grid_y = np.meshgrid(x, y)  # both (H, W) with meshgrid's default 'xy' indexing

    channels = [grid_x, grid_y]
    if homogeneous:
        channels.append(np.ones_like(grid_x))
    return np.stack(channels, axis=-1)


def estimate_focal_knowing_depth(pts3d, focal_mode='median', min_focal=0., max_focal=np.inf):
    """Estimate a single focal length (in pixels) from a per-pixel 3D point map.

    Assumes square pixels and a principal point at the image centre, and looks for
    ``f`` such that ``f * (X/Z, Y/Z)`` matches the centred pixel coordinates.

    Args:
        pts3d: point map of shape (H, W, 3), indexed ``pts3d[row, col]`` -- the same
            ``[y, x]`` indexing used for the MASt3R point maps in this notebook.
        focal_mode: 'median' (median of per-pixel votes) or 'weiszfeld'
            (iteratively re-weighted least squares with inverse-residual weights).
        min_focal, max_focal: clip bounds, as multiples of the focal length that
            gives a 60 degree field of view over the longer image side.

    Returns:
        The estimated (clipped) focal length.

    Raises:
        ValueError: if ``focal_mode`` is not recognised.
    """
    H, W, _ = pts3d.shape  # row-major point map: (height, width, xyz)
    pp = np.array([W / 2, H / 2])  # principal point in (x, y) order

    # Centred pixel grid of shape (H, W, 2); element-wise aligned with pts3d.
    pixels = xy_grid(W, H) - pp

    # Flatten both arrays so that row i of each refers to the same pixel.
    pixels = pixels.reshape(-1, 2)  # (H*W, 2)
    pts3d_flat = pts3d.reshape(-1, 3)  # (H*W, 3)

    if focal_mode == 'median':
        u, v = pixels[:, 0], pixels[:, 1]
        x, y, z = pts3d_flat[:, 0], pts3d_flat[:, 1], pts3d_flat[:, 2]

        # Each pixel votes f = u*Z/X and f = v*Z/Y; 0/0 votes become NaN and are
        # skipped by nanmedian, so silence the division warnings.
        with np.errstate(divide='ignore', invalid='ignore'):
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y

        # Square pixels: pool the x and y votes into a single estimate.
        focal = np.nanmedian(np.concatenate((fx_votes, fy_votes)))

    elif focal_mode == 'weiszfeld':
        # Normalised coordinates (X/Z, Y/Z), shape (H*W, 2).
        with np.errstate(divide='ignore', invalid='ignore'):
            xy_over_z = pts3d_flat[:, :2] / pts3d_flat[:, 2][:, np.newaxis]
        # Z == 0 yields NaN or +/-inf; zero both so they cannot poison the sums below.
        xy_over_z = np.nan_to_num(xy_over_z, nan=0.0, posinf=0.0, neginf=0.0)

        dot_xy_px = (xy_over_z * pixels).sum(axis=1)
        dot_xy_xy = (xy_over_z ** 2).sum(axis=1)

        # L2 closed-form initialisation.
        focal = dot_xy_px.mean() / dot_xy_xy.mean()

        # Iteratively re-weighted least squares: weight each pixel by the inverse of
        # its current reprojection residual.
        for _ in range(10):
            dis = np.linalg.norm(pixels - focal * xy_over_z, axis=1)
            # Clamp the residual *before* inverting: exact fits then get a large but
            # finite weight instead of inf (which would turn focal into NaN).
            w = 1.0 / np.clip(dis, 1e-8, None)
            focal = (w * dot_xy_px).mean() / (w * dot_xy_xy).mean()
    else:
        raise ValueError(f'Invalid focal_mode: {focal_mode}')

    # Clip relative to the focal length of a 60 degree field of view over the longer side.
    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))
    focal = np.clip(focal, min_focal * focal_base, max_focal * focal_base)

    print(f"Estimated Focal Length: {focal}")
    return focal

# Estimate the focal length from the point map of the first image of the MASt3R call above
# (presumably the query). pts3d_im0 is indexed [row, col], i.e. shape (H, W, 3) -- see the
# pts3d_im0[matches[:, 1], matches[:, 0]] lookups in later cells.
# pts3d_im0 = np.random.rand(480, 640, 3)  # Example initialization
# min_focal/max_focal bound the result to [0.5, 3.5] x the focal of a 60 degree FOV over the longer side.
focal_length = estimate_focal_knowing_depth(pts3d_im0, focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5)



In [None]:
# Re-run MASt3R with the keyframe (anchor) first and the query second, visualising the
# matches. pts3d_im0 is now presumably the keyframe's point map and filtered_matches_im1
# the matching query pixels.
(
    filtered_matches_im0, filtered_matches_im1,
    matches_im0, matches_im1,
    pts3d_im0, pts3d_im1,
    conf_im0, conf_im1,
    desc_conf_im0, desc_conf_im1,
) = getMasterOutout(keyframe.filepath, query.filepath, n_matches, visualizeMatches=True)

# The point map is indexed [row, col] (see the lookup below), so its shape is (H, W, 3)
# and the image centre is (W/2, H/2) in (x, y) order.
# NOTE(review): assumes CameraMatrix takes (fx, fy, cx, cy), and that focal_length (estimated
# from the query's point map in the previous cell) also applies to the keyframe.
H, W, _ = pts3d_im0.shape
K_new = CameraMatrix(focal_length, focal_length, W / 2, H / 2)

# 2D-3D correspondences for PnP (predicted transform copied from visloc.py): keyframe 3D
# points at the matched pixels, paired with the matched query pixels. Built once, used twice.
matched_pts3d = pts3d_im0[filtered_matches_im0[:, 1], filtered_matches_im0[:, 0], :].astype(np.float32)
matched_px_query = filtered_matches_im1.astype(np.float32)

# PnP with intrinsics built from the MASt3R focal estimate ...
ret_val, transformation = run_pnp(matched_px_query, matched_pts3d, K_new.astype(np.float32), distortion)
print(transformation)
# ... and with the ZOD calibration rescaled to the MASt3R image size.
# NOTE(review): `distortion` is the ZOD front-camera model at native resolution; confirm
# run_pnp expects coefficients compatible with the rescaled intrinsics.
ret_val, transformation = run_pnp(matched_px_query, matched_pts3d, K_scaled.astype(np.float32), distortion)
print(transformation)

In [None]:
# Compare the rescaled ZOD intrinsics with those built from the MASt3R focal estimate.
for intrinsics_matrix in (K_scaled, K_new):
    print(intrinsics_matrix)

In [None]:
# 3D points of the first image of the latest MASt3R call at the matched pixels.
pointmap = pts3d_im0[filtered_matches_im0[:, 1], filtered_matches_im0[:, 0], :]  # thresholded
pointmap_allmatches = pts3d_im0[matches_im0[:, 1], matches_im0[:, 0], :]  # unthresholded
print(pointmap.shape)

keyframe_lidar = zod_000002.get_keyframe_lidar()
print(keyframe_lidar)
lidar_pointcloud = keyframe_lidar.points

# Top-down comparison: LiDAR columns 0/1 against MASt3R columns 0/2.
# NOTE(review): the LiDAR cloud and the MASt3R point map may use different frames / axis
# conventions (and MASt3R is only defined up to scale) -- confirm before comparing directly.
fig, ax = plt.subplots(figsize=(10, 6))

ax.scatter(lidar_pointcloud[:, 0], lidar_pointcloud[:, 1], color='red', alpha=0.1, label='LiDAR (from ZOD)')
ax.scatter(pointmap_allmatches[:, 0], pointmap_allmatches[:, 2], color='green', alpha=0.5, label='Entire MASt3R Pointmap')
ax.scatter(pointmap[:, 0], pointmap[:, 2], color='blue', alpha=0.7, label='Filtered MASt3R Pointmap')

# Mark the camera at the origin
ax.scatter(0, 0, color='black', s=10, label='Camera', edgecolor='black')

ax.axhline(0, color='gray', linewidth=0.5, linestyle='--')
ax.axvline(0, color='gray', linewidth=0.5, linestyle='--')

ax.set_title('2D Plot of X vs Z')
ax.set_xlabel('X Coordinate')
ax.set_ylabel('Z Coordinate')
ax.grid(True)
ax.legend()
plt.show()


# plt.figure(figsize=(10, 6))
# #plt.scatter(pointmap_allmatches[:, 0], pointmap_allmatches[:, 2], color='green',alpha=0.5, label='All matches (x, z)')
# plt.scatter(lidar_pointcloud[:, 0], lidar_pointcloud[:, 1], color='red', alpha=0.7,label='Filtered Matches (x, z)')
# # Highlight the origin
# plt.scatter(0, 0, color='black', s=10, label='Origin (0, 0)', edgecolor='black')
# plt.axhline(0, color='gray', linewidth=0.5, linestyle='--')
# plt.axvline(0, color='gray', linewidth=0.5, linestyle='--')

# # Add titles and labels
# plt.title('2D Plot of X vs Z')
# plt.xlabel('X Coordinate')
# plt.ylabel('Z Coordinate')
# plt.grid(True)
# plt.legend()
# plt.show()




In [None]:
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

# Output widget that update_plot renders into; update_plot clears it and draws a fresh
# figure on every call, so no figure needs to be created up front.
output = widgets.Output()

def update_plot(scale_factor):
    """Redraw the LiDAR vs. scaled MASt3R top-down comparison inside `output`.

    Args:
        scale_factor: multiplier applied to both MASt3R point sets
            (`pointmap_allmatches` and the thresholded `pointmap`) before plotting.

    The axis limits come from the *unscaled* `pointmap`, so the view stays fixed
    while the slider changes the scale. Reads the notebook globals `output`,
    `pointmap`, `pointmap_allmatches` and `lidar_pointcloud`.
    """
    with output:
        # Replace the previous rendering instead of appending below it.
        output.clear_output(wait=True)

        fig, ax = plt.subplots(figsize=(10, 6))

        scaled_pointmap = pointmap * scale_factor
        scaled_allmatches = pointmap_allmatches * scale_factor

        # (x values, y values, scatter style) in drawing order: LiDAR first, camera last.
        layers = [
            (lidar_pointcloud[:, 0], lidar_pointcloud[:, 1],
             {'color': 'red', 'alpha': 0.1, 'label': 'LiDAR (from ZOD)'}),
            (scaled_allmatches[:, 0], scaled_allmatches[:, 2],
             {'color': 'green', 'alpha': 0.6, 'label': "Scaled Complete mast3r pointmap"}),
            (scaled_pointmap[:, 0], scaled_pointmap[:, 2],
             {'color': 'blue', 'alpha': 0.7, 'label': 'Scaled Filtered mast3r pointmap'}),
            (0, 0, {'color': 'black', 's': 50, 'label': 'Camera'}),
        ]
        for xs, ys, style in layers:
            ax.scatter(xs, ys, **style)

        # Fixed limits from the unscaled thresholded point map, with a minimum extent.
        ax.set_xlim(min(pointmap[:, 0].min(), -1), max(pointmap[:, 0].max(), 9))
        ax.set_ylim(min(pointmap[:, 2].min(), -1), max(pointmap[:, 2].max(), 5))

        for draw_reference_line in (ax.axhline, ax.axvline):
            draw_reference_line(0, color='gray', linewidth=0.5, linestyle='--')

        ax.set_title(f'Pointmap Scaling (Factor: {scale_factor:.2f})')
        ax.set_xlabel('X Coordinate')
        ax.set_ylabel('Z Coordinate')
        ax.grid(True)
        ax.legend()

        plt.tight_layout()
        plt.show()

# Slider for the scale factor; continuous_update=False redraws only once the handle is released.
scale_slider = widgets.FloatSlider(value=1.0, min=0.1, max=5.0, step=0.1,
                                   description='Scale:', continuous_update=False)

# Wire update_plot to slider changes. The interactive widget's own output area is not
# displayed -- update_plot draws into `output` -- but keep a reference to it anyway.
scale_interaction = widgets.interactive(update_plot, scale_factor=scale_slider)

# Show the slider and the plot area
display(scale_slider, output)
