In [1]:
from upsamplers import load_loftup_checkpoint, norm, unnorm
from featurizers import get_featurizer
from sklearn.decomposition import PCA

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np
import cupy as cp
from skimage import color
from skimage.color import lch2lab, lab2rgb
from PIL import Image

from tqdm import tqdm
import time
import math


In [2]:
import os
import rerun as rr

# Rerun recording for the LoftUp comparison; everything logged to the global
# recording below is also written to a timestamped .rrd file.
RERUN_OUTPUT_DIR = "/root/repos/vlmaps/data/allmend_trail_recording_2025_12_06_full_trail_zed_sdk/rerun/loftup"

rr.init("loftup_comparison", recording_id="loftup_comparison")
current_time = time.strftime("%Y%m%d_%H%M%S")
# The nested output directory may not exist yet on a fresh checkout.
os.makedirs(RERUN_OUTPUT_DIR, exist_ok=True)
rr.save(os.path.join(RERUN_OUTPUT_DIR, f"loftup_comparison_{current_time}.rrd"))

In [2]:
class TorchPCA:
    """Minimal PCA for 2-D torch tensors, backed by ``torch.pca_lowrank``.

    Mirrors the subset of scikit-learn's ``PCA`` API used in this notebook:
    ``fit`` returns ``self`` and ``transform`` projects rows onto the learned
    components.

    Attributes set by ``fit``:
        mean_: per-feature mean of the fitted data.
        components_: principal axes, one per row, ``(n_components, n_features)``.
        singular_values_: singular values returned by ``torch.pca_lowrank``.
    """

    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        """Learn the principal axes of ``X`` (rows are samples) and return ``self``.

        ``torch.pca_lowrank`` is randomized and every singular vector is only
        defined up to sign. Each component is therefore oriented so that its
        largest-magnitude entry is positive, which keeps colors derived from
        the projection from flipping between fits.
        """
        self.mean_ = X.mean(dim=0)
        _, sing_vals, right_vecs = torch.pca_lowrank(
            X - self.mean_.unsqueeze(0), q=self.n_components, center=False, niter=4
        )
        axes = right_vecs.T

        # Value of the largest-|.| entry of each component (one per row).
        pivot_cols = axes.abs().argmax(dim=1, keepdim=True)
        pivots = axes.gather(1, pivot_cols).squeeze(1)
        axes[pivots < 0] *= -1

        self.components_ = axes
        self.singular_values_ = sing_vals
        return self

    def transform(self, X):
        """Center ``X`` with the fitted mean and project it onto the components."""
        return (X - self.mean_.unsqueeze(0)) @ self.components_.T

In [3]:
def min_max_scale(tensor, feature_range):
    """Linearly map the values of ``tensor`` onto ``[feature_range[0], feature_range[1]]``.

    The global min and max of ``tensor`` define the mapping. A constant tensor
    (max == min) is filled with the midpoint of the target range instead.
    """
    low = feature_range[0]
    high = feature_range[1]
    t_lo = tensor.min()
    t_hi = tensor.max()
    if t_hi != t_lo:
        unit = (tensor - t_lo) / (t_hi - t_lo)
        return unit * (high - low) + low
    return torch.full_like(tensor, (low + high) / 2)

def _unit_scale_columns(x):
    """Min-max scale every column of a 2-D tensor to ``[0, 1]``.

    Constant columns (zero span) map to 0 instead of producing NaN from 0/0.
    """
    shifted = x - x.min(dim=0, keepdim=True).values
    span = shifted.max(dim=0, keepdim=True).values
    span = torch.where(span > 0, span, torch.ones_like(span))
    return shifted / span


def _pca_to_lch_rgb(x_red, B, H, W):
    """Render the first three PCA components as RGB images through CIE LCh.

    PC1 -> lightness L in [0, 100], PC2 -> chroma C in [0, 100], PC3 -> hue h.
    skimage's ``lch2lab`` expects hue in radians, so PC3 is scaled to
    ``[0, 2*pi]``. Returns a ``(B, 3, H, W)`` CPU tensor built from the
    ``lab2rgb`` output.
    """
    if x_red.shape[1] < 3:
        raise ValueError(f"use_lch=True needs at least 3 PCA components, got {x_red.shape[1]}")
    lch = torch.stack(
        [
            min_max_scale(x_red[:, 0], feature_range=(0, 100)),
            min_max_scale(x_red[:, 1], feature_range=(0, 100)),
            min_max_scale(x_red[:, 2], feature_range=(0, 2 * np.pi)),
        ],
        dim=-1,
    ).reshape(B, H, W, 3).cpu().numpy()
    # Convert image by image so each call sees an (H, W, 3) array, whatever B is.
    rgb = np.stack([lab2rgb(lch2lab(img)) for img in lch], axis=0)
    return torch.from_numpy(rgb).permute(0, 3, 1, 2)


def pca(image_feats_list, dim=3, fit_pca=None, use_torch_pca=True, max_samples=None, use_lch=True):
    """Project feature maps onto principal components for visualization.

    Args:
        image_feats_list: list of feature tensors, either spatial ``(B, C, H, W)``
            or flat ``(N, C)`` (e.g. CLS tokens).
        dim: number of components to fit when ``fit_pca`` is None.
        fit_pca: an already fitted projector (anything with ``.transform``).
            When given, nothing is fitted and all maps share its basis.
        use_torch_pca: fit with ``TorchPCA`` (True) or scikit-learn ``PCA`` (False).
        max_samples: if set, fit on a random subset of at most this many rows.
        use_lch: map the first three components through CIE LCh to RGB.
            Otherwise each component is min-max scaled to ``[0, 1]`` on its own.

    Returns:
        ``(reduced_feats, fit_pca)`` with one entry per input, plus the
        projector that was used. Spatial inputs give ``(B, k, H, W)`` images
        on the device of the first input. Flat inputs give ``(N, k)`` CPU
        tensors scaled to ``[0, 1]``.

    Raises:
        ValueError: if ``use_lch`` is set and the projector yields fewer than 3 components.
    """
    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        # (B, C, H, W) -> (B*H*W, C) on CPU with rows ordered (b, h, w); (N, C) passes through.
        if tensor.dim() == 2:
            return tensor.detach().cpu()
        if target_size is not None:
            # Bring all maps to one common square size so they can be fitted jointly.
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        C = tensor.shape[1]
        return tensor.permute(0, 2, 3, 1).reshape(-1, C).detach().cpu()

    if fit_pca is None:
        # Resize only when several spatial maps are fitted together.
        target_size = None
        if len(image_feats_list) > 1 and image_feats_list[0].dim() != 2:
            target_size = image_feats_list[0].shape[2]

        x = torch.cat([flatten(feats, target_size) for feats in image_feats_list], dim=0)

        # Optionally fit on a random row subset to bound the PCA cost.
        if max_samples is not None and x.shape[0] > max_samples:
            x = x[torch.randperm(x.shape[0])[:max_samples]]

        if use_torch_pca:
            fit_pca = TorchPCA(n_components=dim).fit(x)
        else:
            fit_pca = PCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        x_red = fit_pca.transform(flatten(feats))
        if isinstance(x_red, np.ndarray):
            # scikit-learn's PCA returns numpy arrays.
            x_red = torch.from_numpy(x_red).float()

        if feats.dim() == 2:
            # Flat features (e.g. CLS tokens): per-component [0, 1] scaling, kept on CPU.
            reduced_feats.append(_unit_scale_columns(x_red))
            continue

        B, _, H, W = feats.shape
        if use_lch:
            final_image = _pca_to_lch_rgb(x_red, B, H, W).to(device)
        else:
            n_comp = x_red.shape[1]
            final_image = _unit_scale_columns(x_red).reshape(B, H, W, n_comp).permute(0, 3, 1, 2).to(device)
        reduced_feats.append(final_image)

    return reduced_feats, fit_pca

In [4]:
def gaussian_blur(features, kernel_size=3, sigma=1):
    """Depthwise Gaussian blur of a ``(B, C, H, W)`` tensor.

    Each channel is convolved on its own (``groups=C``) with a normalized
    ``kernel_size x kernel_size`` Gaussian. Borders are zero-padded, so values
    within ``kernel_size // 2`` pixels of the edge are pulled toward 0.

    Args:
        features: 4-D tensor ``(B, C, H, W)``.
        kernel_size: odd side length of the kernel; the output keeps ``H x W``.
        sigma: standard deviation of the Gaussian in pixels (must be > 0).

    Returns:
        Blurred tensor with the same shape, dtype and device as ``features``.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer (an even size
            would change the spatial size of the output) or ``sigma <= 0``.
    """
    if kernel_size < 1 or kernel_size % 2 == 0:
        raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma}")
    _, C, _, _ = features.shape

    half = kernel_size // 2
    ax = torch.arange(-half, half + 1.0)
    xx, yy = torch.meshgrid(ax, ax, indexing="ij")
    kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))
    kernel = kernel / kernel.sum()
    # conv2d needs the weight on the input's device and in the input's dtype.
    kernel = kernel.to(device=features.device, dtype=features.dtype)

    # One identical filter per channel: weight (C, 1, k, k) for groups=C.
    weight = kernel.expand(C, 1, kernel_size, kernel_size)
    return F.conv2d(features, weight, padding=half, groups=C)

def blend_features(features, kernel_size=3, sigma=1):
    """Return an equal-weight mix of ``features`` and its Gaussian-blurred copy.

    Blurring with ``gaussian_blur(features, kernel_size, sigma)`` and averaging
    softens high-frequency noise in the feature map while keeping half of the
    original signal.
    """
    smoothed = gaussian_blur(features, kernel_size=kernel_size, sigma=sigma)
    return 0.5 * features + 0.5 * smoothed

In [14]:
def get_model(base_name="dinov2s_reg", device="cuda", load_size=224):
    """Load a featurizer backbone and its matching LoftUp upsampler.

    Args:
        base_name: featurizer name passed to ``get_featurizer``. The upsampler
            is fetched from torch.hub as ``loftup_<base_name>``.
        device: device both networks are moved to.
        load_size: input resolution handed back to callers as the crop/load size.

    Returns:
        ``(model, upsampler, load_size)``, with both networks in eval mode.
    """
    model, _patch_size, _feat_dim = get_featurizer(base_name)
    model = model.to(device).eval()

    upsampler = torch.hub.load("andrehuang/loftup", f"loftup_{base_name}", pretrained=True)
    # Inference only: switch the upsampler to eval mode like the backbone
    # (nn.Module instances start in training mode).
    upsampler = upsampler.to(device).eval()

    return model, upsampler, load_size

In [6]:
# Featurizers to compare. Other names previously listed here as candidates:
# "dinov2b_reg", "clip", "siglip2".
model_list = [
    "dinov2s_reg",
]

In [7]:
def _window_starts(length, crop_size, stride):
    """Start offsets of sliding windows of size ``crop_size`` covering ``[0, length)``.

    Windows advance by ``stride``. The last one is shifted back so it ends
    exactly at ``length`` instead of overhanging the image (which would need
    padding). When ``length <= crop_size`` a single window at 0 is returned.
    """
    n_windows = max(math.ceil((length - crop_size) / stride), 0) + 1
    return [max(min(i * stride, length - crop_size), 0) for i in range(n_windows)]


def run_model_sliding_window(
    model,
    upsampler,
    img_path,
    crop_size=224,
    stride_rate=2/3,
    batch_size=8,
    max_image_size=518,
    device='cuda'
):
    """
    Runs a model on a high-resolution image by processing it in overlapping
    crops and averaging the upsampled features where crops overlap.

    Args:
        model: The feature extraction model (e.g., ViT).
        upsampler: The model used to upsample low-resolution features; called
                   as ``upsampler(lr_feats, crops)``.
        img_path (str): Path to the high-resolution input image.
        crop_size (int): The square input size required by the model.
        stride_rate (float): Window stride as a fraction of crop_size
                             (2/3 means 1/3 overlap).
        batch_size (int): Number of crops per forward pass, to bound VRAM usage.
        max_image_size (int, optional): If not None, downscale the image so its
                                        longest side is at most this value.
                                        Defaults to 518.
        device (str): The device to run the model on.

    Returns:
        torch.Tensor: Averaged feature map of shape (1, C, h, w) for the
                      (possibly resized) image.
        torch.Tensor: The normalized image tensor of shape (3, h, w).
    """
    transform = T.Compose([T.ToTensor(), norm])

    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")

    # Optional downscale (aspect ratio preserved) to limit the number of crops.
    if max_image_size is not None and max(img.size) > max_image_size:
        original_size = img.size
        if original_size[0] > original_size[1]:  # Landscape
            new_w = max_image_size
            new_h = int(max_image_size * original_size[1] / original_size[0])
        else:  # Portrait or square
            new_h = max_image_size
            new_w = int(max_image_size * original_size[0] / original_size[1])
        print(f"Image resized from {original_size} to ({new_w}, {new_h})")
        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)

    img_tensor = transform(img).unsqueeze(0).to(device)
    _, _, h, w = img_tensor.shape

    # Window grid. max(..., 1) guards against a zero stride for tiny stride_rate.
    stride = max(int(crop_size * stride_rate), 1)
    h_starts = _window_starts(h, crop_size, stride)
    w_starts = _window_starts(w, crop_size, stride)
    print("h_grids, w_grids:", len(h_starts), len(w_starts))

    crops, positions = [], []
    for h0 in h_starts:
        for w0 in w_starts:
            h1 = min(h0 + crop_size, h)
            w1 = min(w0 + crop_size, w)
            crop = img_tensor[:, :, h0:h1, w0:w1]
            # Padding is only non-zero when the image is smaller than crop_size.
            crops.append(F.pad(crop, (0, crop_size - (w1 - w0), 0, crop_size - (h1 - h0))))
            positions.append((h0, h1, w0, w1))

    # Run crops in mini-batches and accumulate each batch straight into the
    # full-size canvas, so all crop features never need to be held at once.
    final_feats = None
    count_norm = torch.zeros(1, 1, h, w, device=device)
    for start in range(0, len(crops), batch_size):
        batch_crops = torch.cat(crops[start:start + batch_size], dim=0)
        with torch.no_grad():
            lr_feats_batch = model(batch_crops)
            hr_feats_batch = upsampler(lr_feats_batch, batch_crops)
        hr_feats_batch = hr_feats_batch.to(device)

        if final_feats is None:
            final_feats = torch.zeros(1, hr_feats_batch.shape[1], h, w, device=device)

        for feat, (h0, h1, w0, w1) in zip(hr_feats_batch, positions[start:start + batch_size]):
            # Drop the padded region (if any) before stitching.
            final_feats[0, :, h0:h1, w0:w1] += feat[:, :h1 - h0, :w1 - w0]
            count_norm[0, :, h0:h1, w0:w1] += 1

    # The window grid covers every pixel at least once, so count_norm >= 1.
    final_feats /= count_norm.clamp_min(1)
    return final_feats, img_tensor.squeeze(0)

In [8]:
def run_model(model, upsampler, img_path, load_size, device="cuda"):
    """Extract upsampled features for one resized, center-cropped image.

    Args:
        model: feature extractor applied to the normalized image batch.
        upsampler: called as ``upsampler(lr_feats, image_batch)``.
        img_path: path to the input image.
        load_size: the image is resized and center-cropped to ``load_size x load_size``.
        device: device the image batch is moved to before inference.

    Returns:
        ``(hr_feats, img_transformed)``: the upsampler output for the one-image
        batch, and the normalized CPU image tensor ``(3, load_size, load_size)``.
    """
    transform = T.Compose([
        T.Resize(load_size, T.InterpolationMode.BILINEAR),
        T.CenterCrop(load_size),  # square crop from the image center
        T.ToTensor(),
        norm,
    ])
    # The context manager closes the file handle once the pixels are decoded.
    with Image.open(img_path) as raw_img:
        img_transformed = transform(raw_img.convert("RGB"))
    normalized_img_tensor = img_transformed.unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        lr_feats = model(normalized_img_tensor)
        hr_feats = upsampler(lr_feats, normalized_img_tensor)

    return hr_feats, img_transformed

In [9]:
def get_frame_list(frame_ranges):
    """Expand ``{"start", "end", "step"}`` range specs into a sorted list of unique frames.

    ``end`` is inclusive. Overlapping ranges are merged, so each frame index
    appears once.
    """
    unique_frames = {
        frame
        for spec in frame_ranges
        for frame in range(spec["start"], spec["end"] + 1, spec["step"])
    }
    return sorted(unique_frames)

In [10]:
# Frames to process: every 10th frame from 400 through 800 (end inclusive).
frame_ranges = [
    dict(start=400, end=800, step=10),
]
rgb_dir = "/root/repos/vlmaps/data/allmend_trail_recording_2025_12_06_full_trail_zed_sdk/images"
frame_list = get_frame_list(frame_ranges)

In [15]:
# Load each model once (the captured output shows torch.hub reusing its cache).
# NOTE(review): the processing cell below calls get_model again for every
# model, so this cell only matters for code that uses `model`/`upsampler`
# before that loop runs; with several models only the last one stays bound.
for model_name in model_list:
    model, upsampler, load_size = get_model(model_name)

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /root/.cache/torch/hub/andrehuang_loftup_main
  checkpoint = torch.load(model_path, map_location="cpu")


In [20]:
# For every model: extract sliding-window LoftUp features for each selected
# frame and log a PCA visualization plus the input frame to rerun, with one
# timeline step per frame.
for model_name in model_list:
    model, upsampler, load_size = get_model(model_name)
    # Fit PCA on the first frame and reuse that basis for the rest, so colors do
    # not jump along the timeline because of a refit (per-frame min-max scaling
    # inside pca still applies).
    frame_pca = None
    for frame_idx in tqdm(frame_list, desc="Processing frames"):
        rr.set_time(timeline="frame", sequence=frame_idx)
        img_path = f"{rgb_dir}/{frame_idx:05d}.png"

        hr_feats, img_transformed = run_model_sliding_window(model, upsampler, img_path, load_size)

        blended_feats = blend_features(hr_feats, kernel_size=3, sigma=1)
        pca_imgs, frame_pca = pca([blended_feats], fit_pca=frame_pca, use_lch=False)
        pca_feats = pca_imgs[0][0]  # first image of the batch: (3, H, W)
        rr.log(f"img/pca_{model_name}", rr.Image(pca_feats.permute(1, 2, 0).cpu()))

        # img_transformed is mean/std-normalized; undo that before logging so
        # rerun shows real colors. Assumes unnorm accepts a batched
        # (B, 3, H, W) tensor — TODO confirm against upsamplers.unnorm.
        img_rgb = unnorm(img_transformed.unsqueeze(0))[0].clamp(0, 1)
        rr.log("img/img", rr.Image(img_rgb.permute(1, 2, 0).cpu()))

Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /root/.cache/torch/hub/andrehuang_loftup_main
  checkpoint = torch.load(model_path, map_location="cpu")
Processing frames:   0%|          | 0/41 [00:00<?, ?it/s]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:   2%|▏         | 1/41 [00:01<01:04,  1.62s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:   5%|▍         | 2/41 [00:03<01:02,  1.61s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:   7%|▋         | 3/41 [00:04<01:00,  1.58s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  10%|▉         | 4/41 [00:06<00:57,  1.57s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  12%|█▏        | 5/41 [00:07<00:56,  1.57s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  15%|█▍        | 6/41 [00:09<00:55,  1.60s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  17%|█▋        | 7/41 [00:11<00:57,  1.68s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  20%|█▉        | 8/41 [00:13<00:56,  1.72s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  22%|██▏       | 9/41 [00:14<00:54,  1.71s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  24%|██▍       | 10/41 [00:16<00:52,  1.70s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  27%|██▋       | 11/41 [00:18<00:49,  1.66s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  29%|██▉       | 12/41 [00:19<00:48,  1.66s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  32%|███▏      | 13/41 [00:21<00:46,  1.65s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  34%|███▍      | 14/41 [00:23<00:45,  1.67s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  37%|███▋      | 15/41 [00:24<00:43,  1.67s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  39%|███▉      | 16/41 [00:26<00:42,  1.68s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  41%|████▏     | 17/41 [00:28<00:41,  1.72s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  44%|████▍     | 18/41 [00:30<00:39,  1.74s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  46%|████▋     | 19/41 [00:32<00:39,  1.81s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  49%|████▉     | 20/41 [00:33<00:38,  1.84s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  51%|█████     | 21/41 [00:35<00:36,  1.85s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  54%|█████▎    | 22/41 [00:37<00:35,  1.86s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  56%|█████▌    | 23/41 [00:39<00:32,  1.83s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  59%|█████▊    | 24/41 [00:41<00:31,  1.85s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  61%|██████    | 25/41 [00:43<00:29,  1.87s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  63%|██████▎   | 26/41 [00:44<00:27,  1.81s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  66%|██████▌   | 27/41 [00:46<00:25,  1.82s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  68%|██████▊   | 28/41 [00:48<00:23,  1.84s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  71%|███████   | 29/41 [00:50<00:21,  1.79s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  73%|███████▎  | 30/41 [00:52<00:19,  1.79s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  76%|███████▌  | 31/41 [00:54<00:18,  1.83s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  78%|███████▊  | 32/41 [00:55<00:16,  1.83s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  80%|████████  | 33/41 [00:57<00:14,  1.77s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  83%|████████▎ | 34/41 [00:59<00:12,  1.72s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  85%|████████▌ | 35/41 [01:00<00:10,  1.69s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  88%|████████▊ | 36/41 [01:02<00:08,  1.68s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  90%|█████████ | 37/41 [01:04<00:06,  1.68s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  93%|█████████▎| 38/41 [01:05<00:04,  1.67s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  95%|█████████▌| 39/41 [01:07<00:03,  1.64s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames:  98%|█████████▊| 40/41 [01:09<00:01,  1.65s/it]

Image resized from (1280, 720) to (518, 291)
h_grids, w_grids: 2 3


Processing frames: 100%|██████████| 41/41 [01:10<00:00,  1.72s/it]


---------------------------------------------------------

## Calculate DINOv2 Features of CaT

--------------------

## Try to Index SigLIP2 Features with Text Queries

In [None]:
# RGB color per semantic class for the segmentation overlay. Key order matters:
# the annotation cell enumerates these keys to assign class ids.
CLASS_COLORS = dict(
    road=(160, 82, 45),          # saddle brown
    vegetation=(133, 255, 48),   # green
    sky=(31, 119, 180),          # steel blue
    gravel=(227, 119, 194),      # orchid pink
    rocks=(127, 127, 127),       # gray
    mud=(255, 127, 14),          # dark orange
    person=(23, 190, 207),       # cyan blue
    other=(158, 218, 229),       # light cyan
)

In [None]:
# Zero-shot segmentation: score every pixel feature against one text prompt per
# class and keep the best-matching class.
# NOTE(review): `model` must provide `forward_text` and produce image features
# in the same embedding space as its text features. With
# model_list == ["dinov2s_reg"], `model` here is the DINOv2 featurizer — confirm
# it supports this before running.
text_template = "a photo of a {label}"
text_inputs = [text_template.format(label=label) for label in CLASS_COLORS.keys()]
print(text_inputs)
with torch.no_grad():
    text_features = model.forward_text(text_inputs)
# Unit-normalize the text embeddings so no class wins only because its
# embedding has a larger norm. Scaling a pixel's feature by a positive factor
# does not change that pixel's argmax, so the image side is left as is.
text_features = F.normalize(text_features, dim=-1)

hr_feats_perm = hr_feats[0].permute(1, 2, 0)
H, W, D = hr_feats_perm.shape
hr_feats_flat = hr_feats_perm.reshape(-1, D)

scores = hr_feats_flat @ text_features.to(hr_feats_flat).T
# Move to CPU so the label map can be logged directly.
class_indices = scores.argmax(dim=1).reshape(H, W).cpu()

In [None]:
# NOTE(review): unpinned install in the middle of the notebook; prefer a pinned
# install in the environment (or a single top-of-notebook install cell) so
# Restart & Run All is reproducible.
!uv pip install "rerun-sdk[notebook]"
import rerun as rr
# NOTE(review): this starts a second recording ("rerun_example_notebook"). The
# logs below presumably go to it rather than to the "loftup_comparison"
# recording / .rrd file configured earlier — confirm this is intended.
rr.init("rerun_example_notebook")

In [None]:
# Register class names and colors once (static) under "img" so the segmentation
# image is drawn with the CLASS_COLORS palette. The id right after the last
# class is reserved for a black "background" entry.
labels_mapping = [
    rr.AnnotationInfo(id=class_id, label=class_name, color=list(rgb))
    for class_id, (class_name, rgb) in enumerate(CLASS_COLORS.items())
]
labels_mapping.append(
    rr.AnnotationInfo(id=len(CLASS_COLORS), label="background", color=[0, 0, 0])
)
rr.log("img", rr.AnnotationContext(labels_mapping), static=True)
rr.log("img/seg_image", rr.SegmentationImage(class_indices))