# ShapeR Evaluation Data Exploration

This notebook explores the format of the per-sample pickle files in the ShapeR Evaluation Dataset.

We'll examine:
1. Data structure for the pickle files containing the samples
2. Point cloud data
3. Image data
4. Object point projections
5. Camera poses and camera intrinsics
6. Ground-truth mesh
7. Captions
8. Dataset visualization for both RGB and non-RGB (SLAM) variants
9. Dataloader usage example

## 1. Pickle File Structure Reference

The following tables describe the structure of the pickle files containing a sample's data. All samples share the same structure.

**Variable Definitions:**
- $N$ = number of semi-dense 3D points from SLAM
- $V$ = number of mesh vertices
- $F$ = number of mesh faces (triangles)
- $I_s$ = number of SLAM camera images
- $I_r$ = number of RGB camera images
- $P_i^s$ = number of points visible in SLAM image $i$ (varies per image)
- $P_i^r$ = number of points visible in RGB image $i$ (varies per image)

### Point Cloud and Bounding Box Data

| Key | Description | Dimensions |
|-----|-------------|------------|
| `points_model` | Semi-dense 3D point cloud from SLAM reconstruction, in object/model coordinate frame | $(N, 3)$ |
| `bounds` | Axis-aligned bounding box half-extents of the object | $(3,)$ |
| `T_model_world` | Transformation from world frame to model/object frame (the `T_zup_obj` key is not used). <br/> See `dataset.shaper_dataset.InferenceDataset.rescale_back` for usage | $(4, 4)$ |
| `inv_dist_std` and `dist_std` | Placeholder for point quality metric (unused, zeroed out in this dataset) | $(N,)$ |

### SLAM Camera Data

| Key | Description | Dimensions |
|-----|-------------|------------|
| `image_data` | JPEG-encoded grayscale images from SLAM camera | list[$I_s$] of bytes |
| `Ts_camera_model` | Transformation from model/object frame to SLAM camera frame (camera extrinsics) | $(I_s, 4, 4)$ |
| `camera_params` | SLAM camera intrinsics (FISHEYE624 distortion model, 15 parameters) | $(I_s, 15)$ |
| `object_point_projections` | 2D pixel coordinates of visible 3D points projected onto each SLAM image | list[$I_s$], each $(P_i^s, 2)$ |
| `visible_points_model` | 3D coordinates of points visible in each SLAM image | list[$I_s$], each $(P_i^s, 3)$ |

### RGB Camera Data

| Key | Description | Dimensions |
|-----|-------------|------------|
| `rgb_image_data` | JPEG-encoded RGB images from RGB camera | list[$I_r$] of bytes |
| `Ts_rgbCamera_model` | Transformation from model/object frame to RGB camera frame | $(I_r, 4, 4)$ |
| `rgb_camera_params` | RGB camera intrinsics (FISHEYE624 distortion model, 15 parameters) | $(I_r, 15)$ |
| `rgb_object_point_projections` | 2D pixel coordinates of points projected onto each RGB image | list[$I_r$], each $(P_i^r, 2)$ |
| `rgb_visible_points_model` | 3D coordinates of points visible in each RGB image | list[$I_r$], each $(P_i^r, 3)$ |

### Ground Truth Mesh

| Key | Description | Dimensions |
|-----|-------------|------------|
| `mesh_vertices` | Ground truth mesh vertex positions | $(V, 3)$ |
| `mesh_faces` | Ground truth mesh face indices (triangles) | $(F, 3)$ |

### Other

| Key | Description | Dimensions |
|-----|-------------|------------|
| `caption` | Text description: category + shape description. If this key is absent, fall back to the `category` key | str |
| `is_ariagen2` | Device flag: `True` = Aria Gen2, `False` = Aria Gen1 | bool |
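Because the per-image entries in these tables have to line up with each other, a quick consistency check can catch a malformed sample early. `check_sample_shapes` below is a hypothetical helper (not part of the ShapeR codebase) that uses only the key names and shapes listed above. It assumes, as the shared symbol $P_i^s$ / $P_i^r$ suggests, that each image's projections and visible 3D points have the same number of rows, and it converts values with `np.asarray`, so it should accept NumPy arrays as well as CPU torch tensors.

```python
import numpy as np


def check_sample_shapes(sample, rgb=False):
    """Hypothetical consistency check based on the pickle tables.

    Returns the number of images for the chosen camera stream and raises
    AssertionError if the per-image entries disagree with each other.
    """
    prefix = "rgb_" if rgb else ""
    images = sample[f"{prefix}image_data"]
    extrinsics_key = "Ts_rgbCamera_model" if rgb else "Ts_camera_model"
    extrinsics = np.asarray(sample[extrinsics_key])
    intrinsics = np.asarray(sample[f"{prefix}camera_params"])
    projections = sample[f"{prefix}object_point_projections"]
    visible = sample[f"{prefix}visible_points_model"]

    n = len(images)
    assert extrinsics.shape == (n, 4, 4), extrinsics.shape
    assert intrinsics.shape == (n, 15), intrinsics.shape
    assert len(projections) == len(visible) == n
    for uv, xyz in zip(projections, visible):
        uv, xyz = np.asarray(uv), np.asarray(xyz)
        # Each 2D projection should pair with one visible 3D point (P_i rows each)
        assert uv.ndim == 2 and uv.shape[1] == 2
        assert xyz.ndim == 2 and xyz.shape[1] == 3
        assert uv.shape[0] == xyz.shape[0]
    points = np.asarray(sample["points_model"])
    assert points.ndim == 2 and points.shape[1] == 3
    return n
```

Once a sample is loaded, `check_sample_shapes(pkl_sample)` checks the SLAM stream and `check_sample_shapes(pkl_sample, rgb=True)` checks the RGB stream.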

## 2. Point Cloud Data

The point cloud is stored in `points_model`. The points are zero-centered but remain in metric (world) scale. To place them at their world position, refer to how `T_model_world` is used in `dataset.shaper_dataset.InferenceDataset.rescale_back`. Let's look at the point cloud for one of the samples.
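`rescale_back` remains the reference for this step. As a hedged illustration only, assuming `T_model_world` follows the usual `T_a_b` convention (it maps homogeneous points from the world frame into the model frame, as the table states), model-frame points can be taken back to the world frame with the inverse transform. `model_to_world` is a hypothetical helper, and it ignores any extra normalization scale the dataloader may apply.

```python
import numpy as np


def model_to_world(points_model, T_model_world):
    """Map (N, 3) model-frame points to the world frame.

    Assumes p_model = T_model_world @ p_world (homogeneous coordinates),
    so p_world = inv(T_model_world) @ p_model.
    """
    points_model = np.asarray(points_model, dtype=np.float64)
    T_world_model = np.linalg.inv(np.asarray(T_model_world, dtype=np.float64))
    # Apply rotation then translation without building homogeneous coordinates
    return points_model @ T_world_model[:3, :3].T + T_world_model[:3, 3]
```

For example, `model_to_world(pkl_sample["points_model"], pkl_sample["T_model_world"])` would give world-frame points under that assumption.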

In [None]:
import pickle
import io
import os
import random

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
from mpl_toolkits.mplot3d import Axes3D
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from dataset.download import setup_data

# Set default figure size
plt.rcParams['figure.figsize'] = [12, 8]

In [None]:
# Default pickle file (same as infer_shape.py default)
PKL_NAME = "BNB2909__pitcher.pkl"
PKL_PATH = f"data/{PKL_NAME}"

setup_data(PKL_NAME)

# Load the pickle file
with open(PKL_PATH, "rb") as f:
    pkl_sample = pickle.load(f)

# Extract point cloud
points = pkl_sample["points_model"].numpy()
bounds = pkl_sample["bounds"].numpy()

# Point quality metrics: placeholders, unused and zeroed out in this dataset
inv_dist_std = pkl_sample["inv_dist_std"].numpy()
dist_std = pkl_sample["dist_std"].numpy()

print(f"Point cloud shape: {points.shape}")
print(f"Bounds: {bounds}")
print(f"Point range: [{points.min():.3f}, {points.max():.3f}]")

In [None]:
# Interactive point cloud visualization using plotly

# Subsample for performance
idx = np.random.choice(len(points), min(5000, len(points)), replace=False)
pts = points[idx]

fig = make_subplots(
    rows=1, cols=1,
    specs=[[{'type': 'scatter3d'}]],
    subplot_titles=('Point Cloud',)
)

# Subsampled point cloud
fig.add_trace(
    go.Scatter3d(
        x=pts[:, 0], y=pts[:, 1], z=pts[:, 2],
        mode='markers',
        marker=dict(size=2, color='blue', opacity=0.6),
        name='Points'
    ),
    row=1, col=1
)

fig.update_layout(
    title=f"Point Cloud Visualization ({len(points)} points, showing {len(idx)} subsampled)",
    height=500,
    showlegend=False,
    scene=dict(aspectmode='data'),  # Preserve the object's true proportions
)

fig.show()

## 3. Image Data (SLAM Views)

Images are stored as encoded image bytes (e.g., JPEG), one entry for every frame in the sequence in which the object was observed. Let's decode and visualize them.

In [None]:
image_data = pkl_sample["image_data"]
print(f"Number of SLAM images: {len(image_data)}")

# Decode first image to check format
first_img = Image.open(io.BytesIO(image_data[0]))
print(f"Image mode: {first_img.mode}")
print(f"Image size: {first_img.size}")

rgb_image_data = pkl_sample["rgb_image_data"]
print(f"\nNumber of RGB images: {len(rgb_image_data)}")
first_rgb = Image.open(io.BytesIO(rgb_image_data[0]))
print(f"RGB Image mode: {first_rgb.mode}")
print(f"RGB Image size: {first_rgb.size}")

In [None]:
# Visualize a grid of images
n_show = min(16, len(image_data))
cols = 4
rows = (n_show + cols - 1) // cols

fig, axes = plt.subplots(rows, cols, figsize=(12, 3*rows), squeeze=False)
axes = axes.flatten()  # squeeze=False always returns a 2D array of axes

selected_random_image_indices = random.sample(range(len(image_data)), n_show)
for ax_idx, im_idx in enumerate(selected_random_image_indices):
    img = Image.open(io.BytesIO(image_data[im_idx])).convert('L')
    axes[ax_idx].imshow(np.array(img), cmap='gray')
    axes[ax_idx].set_title(f"View {im_idx}")
    axes[ax_idx].axis('off')

# Hide unused axes
for i in range(n_show, len(axes)):
    axes[i].axis('off')

plt.suptitle("SLAM Grayscale Images", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Object Point Projections

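`object_point_projections` holds, for each SLAM image, the 2D pixel coordinates of the semi-dense points visible in that image; the corresponding 3D points are in `visible_points_model`. Overlaying them on the images shows where the object lies in each view.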
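The dataloader later turns these projections into per-view point masks (`masks_ingest`, shown in section 8.1). Its exact implementation is not reproduced here; as a rough, hypothetical sketch, projections can be rasterized into a binary mask by rounding each coordinate to the nearest pixel and discarding points that land outside the image. The sketch assumes `u` is the column (x) and `v` the row (y), which matches how the cell below scatters them over `imshow`.

```python
import numpy as np


def projections_to_mask(uv, width, height):
    """Rasterize (P, 2) pixel coordinates (u = column, v = row) into a boolean mask.

    Hypothetical helper for illustration; not the ShapeR dataloader's code.
    """
    uv = np.asarray(uv, dtype=np.float64)
    mask = np.zeros((height, width), dtype=bool)
    if uv.size == 0:
        return mask
    cols = np.round(uv[:, 0]).astype(int)
    rows = np.round(uv[:, 1]).astype(int)
    # Keep only projections that land inside the image
    inside = (cols >= 0) & (cols < width) & (rows >= 0) & (rows < height)
    mask[rows[inside], cols[inside]] = True
    return mask
```

With the images above, `projections_to_mask(projections[0], *first_img.size)` would work, since PIL reports `size` as `(width, height)`.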

In [None]:
projections = pkl_sample["object_point_projections"]
print(f"Number of projection sets: {len(projections)}")
print(f"First projection shape: {projections[0].shape}")

# Visualize projections on images
fig, axes = plt.subplots(rows, cols, figsize=(12, 3*rows), squeeze=False)
axes = axes.flatten()  # squeeze=False always returns a 2D array of axes
for ax_idx, im_idx in enumerate(selected_random_image_indices):
    img = Image.open(io.BytesIO(image_data[im_idx])).convert('L')
    axes[ax_idx].imshow(np.array(img), cmap='gray')

    # Plot projections
    uv = projections[im_idx].numpy()
    axes[ax_idx].scatter(uv[:, 0], uv[:, 1], s=1, c='lime', alpha=0.5)
    axes[ax_idx].set_title(f"View {im_idx}: {len(uv)} points")
    axes[ax_idx].axis('off')

plt.suptitle("Point Projections on Images", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

## 5. Camera Data
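Each entry of `Ts_camera_model` maps model-frame points into a camera frame. Writing it as $[R \mid t]$, the camera center $c$ is the model-frame point that lands at the camera origin, $R c + t = 0$, so $c = -R^\top t$. That is exactly the translation column of the inverse transform, which the next cell extracts. The batched helper below (a sketch; `camera_centers_from_extrinsics` is a hypothetical name) computes the same quantity without a matrix inverse.

```python
import numpy as np


def camera_centers_from_extrinsics(Ts_camera_model):
    """Camera centers in the model frame for a stack of (I, 4, 4) model-to-camera transforms.

    For T = [R | t], the center c satisfies R @ c + t = 0, so c = -R^T t.
    """
    T = np.asarray(Ts_camera_model, dtype=np.float64)
    R, t = T[:, :3, :3], T[:, :3, 3]
    # Batched -R^T t: sum over the row index j of each R
    return -np.einsum("nji,nj->ni", R, t)
```

Applied to `pkl_sample["Ts_camera_model"]`, it should agree with the `camera_centers` computed below, up to floating-point error.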

In [None]:
# Camera extrinsics (model-to-camera transforms)
Ts_camera = pkl_sample["Ts_camera_model"]
print(f"Camera extrinsics shape: {Ts_camera.shape}")
print(f"Number of camera poses: {len(Ts_camera)}")

# Camera centers in the model frame: translation column of each inverted extrinsic
camera_centers = np.linalg.inv(Ts_camera.numpy())[:, :3, 3]

print("\nCamera center range:")
print(f"  X: [{camera_centers[:, 0].min():.3f}, {camera_centers[:, 0].max():.3f}]")
print(f"  Y: [{camera_centers[:, 1].min():.3f}, {camera_centers[:, 1].max():.3f}]")
print(f"  Z: [{camera_centers[:, 2].min():.3f}, {camera_centers[:, 2].max():.3f}]")

# Intrinsics are FISHEYE624 parameter vectors (15 values each), not 3x3 matrices
camera_params = pkl_sample["camera_params"]
print(f"\nCamera intrinsics: {len(camera_params)} parameter vectors")
print(f"First camera's intrinsic parameters:\n{camera_params[0]}")

In [None]:
# Subsample point cloud
idx = np.random.choice(len(points), min(2000, len(points)), replace=False)

fig = go.Figure()

# Plot point cloud
fig.add_trace(
  go.Scatter3d(
      x=points[idx, 0], y=points[idx, 1], z=points[idx, 2],
      mode='markers',
      marker=dict(size=1, color='blue', opacity=0.3),
      name='Point Cloud'
  )
)

# Plot camera positions
fig.add_trace(
  go.Scatter3d(
      x=camera_centers[:, 0], y=camera_centers[:, 1], z=camera_centers[:, 2],
      mode='markers',
      marker=dict(size=3, color='red', symbol='diamond'),
      name='Camera Positions'
  )
)

fig.update_layout(
  title="Camera Positions Around Object",
  scene=dict(
      xaxis_title='X',
      yaxis_title='Y',
      zaxis_title='Z',
      aspectmode='data',  # Equal aspect ratio based on data ranges
  ),
  height=600,
  legend=dict(x=0.8, y=0.9),
)

fig.show()

## 6. Ground Truth Mesh

In [None]:
mesh_verts = pkl_sample["mesh_vertices"].numpy()
mesh_faces = pkl_sample["mesh_faces"].numpy()

print(f"Mesh vertices: {mesh_verts.shape}")
print(f"Mesh faces: {mesh_faces.shape}")
print(f"Vertex range: [{mesh_verts.min():.3f}, {mesh_verts.max():.3f}]")

# Interactive mesh visualization
fig = go.Figure()

fig.add_trace(
  go.Mesh3d(
      x=mesh_verts[:, 0],
      y=mesh_verts[:, 1],
      z=mesh_verts[:, 2],
      i=mesh_faces[:, 0],
      j=mesh_faces[:, 1],
      k=mesh_faces[:, 2],
      color='gray',
      opacity=1.0,
      flatshading=False,
      lighting=dict(ambient=0.5, diffuse=0.8, specular=0.2),
      lightposition=dict(x=100, y=200, z=300),
      name='Ground Truth Mesh'
  )
)

fig.update_layout(
  title=f"Ground Truth Mesh ({len(mesh_verts)} vertices, {len(mesh_faces)} faces)",
  scene=dict(
      xaxis_title='X',
      yaxis_title='Y',
      zaxis_title='Z',
      aspectmode='data',
  ),
  height=700,
)

fig.show()

## 7. Caption / Category

In [None]:
if "caption" in pkl_sample:
    print(f"Caption: {pkl_sample['caption']}")
elif "category" in pkl_sample:
    print(f"Category: {pkl_sample['category']}")
else:
    print("No caption or category found")

# Check for device type flag
print(f"\nis_ariagen2: {pkl_sample['is_ariagen2']}")

---

## 8. Using the InferenceDataset

Now let's see how the `InferenceDataset` processes this data. It does much of the heavy lifting for object-centric reconstruction: selecting views with a view-selection heuristic, cropping each image around the object, and packing everything needed to perform reconstruction.
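The default `cluster` view-selection strategy is described as K-means clustering on camera positions. The sketch below illustrates that idea only; it is not the `get_image_data_based_on_strategy` implementation, and the iteration count, the random seeding, and picking the real view nearest each centroid are all assumptions. `select_views_kmeans` is a hypothetical name, written in plain NumPy so it has no extra dependencies.

```python
import numpy as np


def select_views_kmeans(camera_centers, num_views, num_iters=20, seed=0):
    """Pick `num_views` diverse view indices by K-means on camera centers.

    Illustrative sketch: Lloyd's algorithm, then the view closest to each
    centroid. Not ShapeR's implementation.
    """
    X = np.asarray(camera_centers, dtype=np.float64)
    n = len(X)
    if num_views >= n:
        return np.arange(n)
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(n, num_views, replace=False)]
    for _ in range(num_iters):
        # Assign every camera to its nearest centroid
        d = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=-1)
        labels = d.argmin(axis=1)
        for k in range(num_views):
            members = X[labels == k]
            if len(members) > 0:  # keep the old centroid if a cluster empties
                centroids[k] = members.mean(axis=0)
    # Represent each cluster by its closest actual view
    d = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=-1)
    chosen = []
    for k in range(num_views):
        order = np.argsort(d[:, k])
        # Skip views already chosen for a nearby centroid
        pick = next(i for i in order if i not in chosen)
        chosen.append(int(pick))
    return np.sort(np.array(chosen))
```

With the camera centers from section 5, `select_views_kmeans(camera_centers, num_views=4)` returns four distinct frame indices spread around the object.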

In [None]:
import omegaconf
from dataset.shaper_dataset import InferenceDataset

# Load config
config = omegaconf.OmegaConf.load("checkpoints/config.yaml")

# Create the dataset with a chosen number of views
num_views = 4  # e.g., 4, 8, or 16

dataset = InferenceDataset(
    config,
    paths=[PKL_PATH],
    override_num_views=num_views,
)

# Get a sample
sample = dataset[0]

print(f"Sample keys: {list(sample.keys())}")
print(f"\nSample details:")
for key, value in sample.items():
    if isinstance(value, np.ndarray):
        print(f"  {key}: np.ndarray, shape={value.shape}, dtype={value.dtype}")
    elif isinstance(value, str):
        print(f"  {key}: '{value[:50]}...'" if len(value) > 50 else f"  {key}: '{value}'")
    else:
        print(f"  {key}: {type(value).__name__}")

### 8.1 Visualize Dataloader Output (Non-RGB / SLAM Variant)

In [None]:
# Visualize processed images from dataloader
images = sample["images"]  # Shape: (num_views, C, H, W)
masks = sample.get("masks_ingest", None)

n_images = len(images)
cols = min(4, n_images)
rows = (n_images + cols - 1) // cols

fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
axes = axes.flatten() if n_images > 1 else [axes]

for i in range(n_images):
    # Images are (C, H, W), transpose to (H, W, C)
    img = images[i].transpose(1, 2, 0)
    if img.shape[2] == 1:
        img = img[:, :, 0]
        axes[i].imshow(img, cmap='gray')
    else:
        axes[i].imshow(img)
    axes[i].set_title(f"Processed View {i}")
    axes[i].axis('off')

for i in range(n_images, len(axes)):
    axes[i].axis('off')

plt.suptitle(f"Dataloader Output: {n_images} Views (SLAM/Non-RGB)", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Show masks
if masks is not None:
    fig, axes = plt.subplots(rows, cols, figsize=(4*cols, 4*rows))
    axes = axes.flatten() if n_images > 1 else [axes]

    for i in range(n_images):
        axes[i].imshow(masks[i], cmap='gray')
        axes[i].set_title(f"Mask {i}")
        axes[i].axis('off')

    for i in range(n_images, len(axes)):
        axes[i].axis('off')

    plt.suptitle("Point Projection Masks", fontsize=14, fontweight='bold')
    plt.tight_layout()

    plt.show()

The rotation here is intentional: the ShapeR model was trained on simulated Aria Gen 1 images, which are rotated in the same way.

In [None]:
# Visualize processed point cloud
processed_points = sample["semi_dense_points"]

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

idx = np.random.choice(len(processed_points), min(5000, len(processed_points)), replace=False)
ax.scatter(processed_points[idx, 0], processed_points[idx, 1], processed_points[idx, 2], s=2, alpha=0.5)

ax.set_title(f"Processed Point Cloud (Filtered & Scaled)\n{len(processed_points)} points")
ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
ax.set_xlim(-1, 1); ax.set_ylim(-1, 1); ax.set_zlim(-1, 1)
plt.show()

print(f"Scale factor: {sample['scale']:.4f}")
print(f"Caption: {sample['caption']}")

### 8.2 Visualize Dataloader Output (RGB Variant)

If RGB data is available, we can switch the image processor to use RGB images.

In [None]:
# To use RGB images, call the image processor directly.
# The key difference is is_rgb=True in get_image_data_based_on_strategy
from dataset.image_processor import get_image_data_based_on_strategy, crop_pad_preselected_views_with_background

scale = 0.9 / np.max(pkl_sample["bounds"].numpy())

# Get RGB images
(
    rectified_images_rgb,
    rectified_point_masks_rgb,
    rectified_camera_params_rgb,
    selected_view_ext_rgb,
) = get_image_data_based_on_strategy(
    pkl_sample,
    num_views=16,
    scale=scale,
    is_rgb=True,  # KEY DIFFERENCE: Use RGB
    strategy="cluster",
)

(
    selected_view_imgs_rgb,
    _,
    rectified_masks_rgb,
    _,
) = crop_pad_preselected_views_with_background(
    rectified_images_rgb,
    rectified_point_masks_rgb,
    rectified_camera_params_rgb,
    config.encoder.dino_image_size,
    add_point_locations=False,
)

# Visualize RGB images
n_rgb = len(selected_view_imgs_rgb)
fig, axes = plt.subplots(1, n_rgb, figsize=(4*n_rgb, 4))
if n_rgb == 1:
    axes = [axes]

for i in range(n_rgb):
    img = selected_view_imgs_rgb[i].transpose(1, 2, 0)
    axes[i].imshow(img)
    axes[i].set_title(f"RGB View {i}")
    axes[i].axis('off')

plt.suptitle("RGB Variant Images", fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

Note that semi-dense point tracking was performed on SLAM images only. SLAM images therefore carry reliable object-visibility information, whereas the object's visibility in the RGB images is estimated heuristically; there is a very small chance that a selected RGB frame does not actually show the object, for example because of occluders.
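If you want an extra guard against such frames, one simple heuristic is to count how many of a frame's projected points fall inside the image and drop frames below a threshold. This is a hypothetical sketch, not what the dataset does: `min_points` is an arbitrary illustrative value, and because the RGB projections are themselves heuristic, the filter cannot detect occlusion directly.

```python
import numpy as np


def frames_with_enough_points(projections, width, height, min_points=50):
    """Indices of frames with at least `min_points` projections inside the image.

    Hypothetical filter; `min_points` is an arbitrary illustrative threshold.
    """
    keep = []
    for i, uv in enumerate(projections):
        uv = np.asarray(uv, dtype=np.float64).reshape(-1, 2)
        inside = (
            (uv[:, 0] >= 0) & (uv[:, 0] < width)
            & (uv[:, 1] >= 0) & (uv[:, 1] < height)
        )
        if inside.sum() >= min_points:
            keep.append(i)
    return keep
```

For example, `frames_with_enough_points(pkl_sample["rgb_object_point_projections"], *first_rgb.size)` assumes every RGB frame has the same size as the first one.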

## 9. Using the DataLoader with Collate Function
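Why a custom collate function? Default batching stacks every field, which fails for per-sample arrays of different lengths such as mesh vertices and faces. The framework-agnostic NumPy sketch below shows only that part of the idea: fields named in `variable_keys` (and strings) stay as Python lists, and everything else is stacked along a new batch axis. It is hypothetical, its key names are placeholders, and it omits the SparseTensor construction that `dataset.custom_collate` also performs.

```python
import numpy as np


def collate_keep_variable(samples, variable_keys=("vertices", "faces")):
    """Batch a list of sample dicts: stack fixed-size fields, list variable-size ones.

    `variable_keys` names are placeholders; the real sample keys may differ.
    """
    batch = {}
    for key in samples[0]:
        values = [s[key] for s in samples]
        if key in variable_keys or isinstance(values[0], str):
            batch[key] = values  # keep per-sample objects as a list
        else:
            batch[key] = np.stack([np.asarray(v) for v in values])
    return batch
```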

In [None]:
# Create a DataLoader with the custom collate function
# The custom_collate handles:
# 1. Point cloud preprocessing into a SparseTensor for efficient 3D convolutions
# 2. Keeping vertices/faces as lists (variable size per sample)
# 3. Standard batching for everything else

inference_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,  # Can increase for multiple samples
    shuffle=False,
    drop_last=False,
    num_workers=0,
    collate_fn=dataset.custom_collate,  # KEY: Use custom collate function
)

# Get a batch
batch = next(iter(inference_loader))

print("Batch keys and types:")
print("=" * 60)
for key, value in batch.items():
    if hasattr(value, 'shape'):
        print(f"  {key}: {type(value).__name__}, shape={tuple(value.shape)}")
    elif isinstance(value, list):
        if len(value) > 0 and hasattr(value[0], 'shape'):
            print(f"  {key}: list[{len(value)}] of {type(value[0]).__name__}, first shape={value[0].shape}")
        else:
            print(f"  {key}: list[{len(value)}]")
    else:
        print(f"  {key}: {type(value).__name__}")

In [None]:
# The semi_dense_points is now a SparseTensor (from torchsparse)
# This is an efficient representation for 3D point data used by sparse convolutions
sparse_points = batch["semi_dense_points"]
print("SparseTensor structure:")
print(f"  Coordinates (C): {sparse_points.C.shape} - quantized 3D coordinates + batch index")
print(f"  Features (F): {sparse_points.F.shape} - point features (original xyz coordinates)")

# Move batch to device (GPU if available) with optional dtype conversion
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = InferenceDataset.move_batch_to_device(batch, device, dtype=torch.bfloat16)

print(f"\nBatch moved to: {device}")
print(f"Images dtype: {batch['images'].dtype}")

### Key Points about the DataLoader

1. **Custom Collate Function**: The `custom_collate` function is essential because:
   - Point clouds are converted to `torchsparse.SparseTensor` for efficient sparse 3D convolutions
   - Mesh vertices/faces are kept as lists since they have variable sizes across samples
   
2. **SparseTensor**: The point cloud is quantized into discrete bins and stored as a sparse tensor with:
   - `C` (coordinates): Quantized (x, y, z, batch_idx) integers
   - `F` (features): Original continuous xyz coordinates
   
3. **Device Transfer**: Use `InferenceDataset.move_batch_to_device()` to properly move all tensors (including SparseTensors) to GPU with optional dtype conversion (e.g., bfloat16 for memory efficiency)

4. **View Selection Strategies**: The dataset supports different strategies for selecting which views to use:
   - `cluster`: K-means clustering on camera positions for diverse viewpoints (default)
   - `last_n`: Last N views in the capture sequence
   - `view_angle`: Hemisphere-based selection for even angular coverage
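To make the quantization in point 2 concrete, here is a NumPy sketch that builds `C`-style integer coordinates `(x, y, z, batch_idx)` and keeps the continuous xyz as features, one row per occupied voxel. The voxel size and the keep-the-first-point-per-voxel rule are assumptions for illustration; `custom_collate` may choose them differently, and some torchsparse versions may place the batch column elsewhere.

```python
import numpy as np


def quantize_points(point_clouds, voxel_size=0.01):
    """Sparse-tensor-style (coords, feats) arrays for a batch of (N_b, 3) clouds.

    coords: int32 (M, 4) rows of (x, y, z, batch_idx), one per occupied voxel.
    feats: float32 (M, 3) original xyz of the first point that fell in each voxel.
    `voxel_size` is an illustrative assumption.
    """
    all_coords, all_feats = [], []
    for b, pts in enumerate(point_clouds):
        pts = np.asarray(pts, dtype=np.float32)
        vox = np.floor(pts / voxel_size).astype(np.int32)
        # Drop duplicate voxels, keeping the first point that landed in each
        _, first = np.unique(vox, axis=0, return_index=True)
        first = np.sort(first)
        batch_col = np.full((len(first), 1), b, dtype=np.int32)
        all_coords.append(np.hstack([vox[first], batch_col]))
        all_feats.append(pts[first])
    return np.concatenate(all_coords), np.concatenate(all_feats)
```

Keeping the raw xyz as features means the quantization only decides which points share a voxel; the sub-voxel position is still available to the network.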