In [1]:
# Environment setup: clone the DinoV3Production repo, move into it, and install the
# package in editable mode so `dinov3production` can be imported below.
# NOTE(review): not idempotent — after `%cd`, re-running this cell would clone a nested
# copy inside DinoV3Production/. Run it once per fresh runtime.
!git clone https://github.com/mhsefidgar/DinoV3Production.git
%cd DinoV3Production
!pip install -e .

Cloning into 'DinoV3Production'...
remote: Enumerating objects: 88, done.[K
remote: Counting objects: 100% (88/88), done.[K
remote: Compressing objects: 100% (79/79), done.[K
remote: Total 88 (delta 22), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (88/88), 59.51 KiB | 4.96 MiB/s, done.
Resolving deltas: 100% (22/22), done.
/content/DinoV3Production
Obtaining file:///content/DinoV3Production
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting onnx (from dinov3production==0.1.0)
  Downloading onnx-1.20.0-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (8.4 kB)
Collecting onnxruntime (from dinov3production==0.1.0)
  Downloading onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting coloredlogs (from onnxruntime->dinov3production==0.1.0)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime->dinov3pro

# Dense and Sparse Correspondence

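This notebook establishes dense and sparse patch correspondences between two images of the same object using DINOv3 features. Note: the feature-extraction step below currently substitutes random mock features, and the model is built with `pretrained=False`. The matches shown are therefore placeholders until real backbone features are enabled.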

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from PIL import Image
import urllib
from tqdm import tqdm
from sklearn.decomposition import PCA

from dinov3production import create_model
from dinov3production.data.transforms import resize_to_patch_multiple
from dinov3production.matching import stratify_points
import torchvision.transforms.functional as TF

# Choose the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the backbone, move it to the chosen device, and put it in eval mode.
# NOTE: pretrained=False gives random weights; switch to pretrained=True for real use.
model = create_model('dinov3_vitl14', pretrained=False)
model.to(device)
model.eval()

# Side length (in pixels) of one ViT patch; keep it in sync with the backbone's patch size.
PATCH_SIZE = 14
# Target size handed to resize_to_patch_multiple (presumably the output height before
# rounding to a PATCH_SIZE multiple — TODO confirm against the library).
IMAGE_SIZE = 768

## 1. Data Loading

In [None]:
def load_image_from_url(url: str) -> Image.Image:
    """Download an image and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        Whatever ``urllib.request.urlopen`` or ``PIL.Image.open`` raise on failure
        (e.g. ``urllib.error.URLError`` or ``PIL.UnidentifiedImageError``).
    """
    # A bare `import urllib` does not guarantee the `urllib.request` submodule is loaded,
    # so import it explicitly to make `urllib.request.urlopen` always resolvable.
    import urllib.request

    with urllib.request.urlopen(url) as response:
        # convert() forces PIL to decode the pixels while the response is still open.
        return Image.open(response).convert("RGB")

# Example image pair and foreground masks, hosted under dl.fbaipublicfiles.com/dinov3/notebooks.
image_left_uri = "https://dl.fbaipublicfiles.com/dinov3/notebooks/dense_sparse_matching/image_left.jpg"
mask_left_uri = "https://dl.fbaipublicfiles.com/dinov3/notebooks/dense_sparse_matching/image_left_fg.png"
image_right_uri = "https://dl.fbaipublicfiles.com/dinov3/notebooks/dense_sparse_matching/image_right.jpg"
mask_right_uri = "https://dl.fbaipublicfiles.com/dinov3/notebooks/dense_sparse_matching/image_right_fg.png"

try:
    image_left = load_image_from_url(image_left_uri)
    mask_left = load_image_from_url(mask_left_uri)
    image_right = load_image_from_url(image_right_uri)
    mask_right = load_image_from_url(mask_right_uri)
except Exception as exc:
    # Best-effort offline fallback so the rest of the notebook still runs. Report it
    # loudly: silently substituting flat-colour images would make every result misleading.
    # (Catching Exception rather than a bare except keeps KeyboardInterrupt working.)
    print(f"Could not download the example images ({exc!r}); using synthetic placeholders.")
    image_left = Image.new('RGB', (800, 600), color='salmon')
    mask_left = Image.new('L', (800, 600), color=255)  # all-foreground dummy mask
    image_right = Image.new('RGB', (800, 600), color='coral')
    mask_right = Image.new('L', (800, 600), color=255)  # all-foreground dummy mask

# Preview the two input images side by side.
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(10, 4))
ax_left.imshow(image_left)
ax_left.set_title("Left")
ax_left.axis("off")
ax_right.imshow(image_right)
ax_right.set_title("Right")
ax_right.axis("off")
plt.show()

## 2. Feature Extraction

In [None]:
# Box filter with stride PATCH_SIZE: averages each PATCH_SIZE x PATCH_SIZE tile of the
# mask, giving one foreground fraction per ViT patch.
patch_quant_filter = torch.nn.Conv2d(1, 1, PATCH_SIZE, stride=PATCH_SIZE, bias=False)
torch.nn.init.constant_(patch_quant_filter.weight, 1.0 / (PATCH_SIZE * PATCH_SIZE))

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# The mock features below stand in for real backbone output. Seed them so that
# Restart & Run All reproduces the same matches and figures.
MOCK_FEATURE_DIM = 1024  # channel width of the mock feature maps
MOCK_FEATURE_SEED = 0
mock_generator = torch.Generator().manual_seed(MOCK_FEATURE_SEED)

patch_mask_values = []  # per image: per-patch foreground fraction (CPU tensor)
patch_features = []  # per image: [D, H, W] patch features (CPU tensor)

# No autocast: float32 is already PyTorch's default precision, so a float32 autocast
# does nothing (and device_type='cuda' only warns on CPU-only machines). If the real
# backbone is enabled and mixed precision is wanted, wrap the model call in
# torch.autocast(device_type=device, dtype=torch.bfloat16).
with torch.inference_mode():
    for image, mask in [(image_left, mask_left), (image_right, mask_right)]:
        # Per-patch foreground fraction: grayscale mask -> patch-multiple resize -> box filter.
        mask_resized = resize_to_patch_multiple(mask.convert('L'), PATCH_SIZE, IMAGE_SIZE)
        mask_quantized = patch_quant_filter(mask_resized).squeeze().cpu()
        patch_mask_values.append(mask_quantized)

        # ImageNet-normalised input batch for the backbone (unused while features are mocked).
        image_resized = resize_to_patch_multiple(image, PATCH_SIZE, IMAGE_SIZE)
        image_norm = TF.normalize(image_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD).unsqueeze(0).to(device)

        # Real extraction (expected output [1, D, H, W]) — enable once the model supports it:
        # feats = model.get_intermediate_layers(image_norm, n=1, reshape=True)[0]
        # Mock: seeded random features on the same patch grid as the mask.
        h, w = mask_quantized.shape
        feats = torch.randn(1, MOCK_FEATURE_DIM, h, w, generator=mock_generator).to(device)

        patch_features.append(feats.squeeze(0).cpu())

## 3. Matching Patches

In [None]:
MASK_FG_THRESHOLD = 0.5  # minimum foreground fraction for a patch to count as foreground
dim = patch_features[0].shape[0]

# Unit-normalise along the channel axis so dot products become cosine similarities.
feat0 = F.normalize(patch_features[0], p=2, dim=0)
feat1 = F.normalize(patch_features[1], p=2, dim=0)

h1, w1 = feat0.shape[1:]
h2, w2 = feat1.shape[1:]

# The per-patch masks must line up with the feature grids. Otherwise the gather and the
# boolean indexing below silently pick the wrong patches.
assert patch_mask_values[0].shape == (h1, w1), "left mask grid does not match left feature grid"
assert patch_mask_values[1].shape == (h2, w2), "right mask grid does not match right feature grid"


def patch_centers(flat_idx: torch.Tensor, grid_width: int) -> torch.Tensor:
    """Map flat patch indices on a grid of width `grid_width` to (row, col) pixel
    coordinates of the patch centres in the resized image (stacked on a new last axis)."""
    rows_cols = torch.stack((flat_idx // grid_width, flat_idx % grid_width), dim=-1).float()
    return (rows_cols + 0.5) * PATCH_SIZE


# Cosine similarity of every left patch to every right patch, flattened to [h1, w1, h2*w2].
heatmaps = torch.einsum("k h w, k i j -> h w i j", feat0, feat1).flatten(start_dim=2)

# Best-matching right patch for each left patch.
max_val, max_idx = heatmaps.max(dim=-1)

locs_2d_left = patch_centers(torch.arange(h1 * w1).reshape(h1, w1), w1)
locs_2d_right = patch_centers(max_idx, w2)

# Keep a match only if the left patch AND its matched right patch are both foreground.
mask1 = patch_mask_values[0] > MASK_FG_THRESHOLD
mask2_mapped = patch_mask_values[1].reshape(-1)[max_idx.reshape(-1)].reshape(h1, w1) > MASK_FG_THRESHOLD
selection = mask1 & mask2_mapped

locs_2d_left_fg = locs_2d_left[selection]
locs_2d_right_fg = locs_2d_right[selection]

print(f"Selected {len(locs_2d_left_fg)} matches.")

## 4. Dense Correspondence (Rainbow PCA)

In [None]:
def pca_rgb(feats: torch.Tensor, height: int, width: int, fitted_pca: PCA) -> torch.Tensor:
    """Project a [D, height, width] feature map onto the 3 components of `fitted_pca`.

    Returns a [3, height, width] tensor squashed into (0, 1) with sigmoid(2 * projection),
    suitable for display as an RGB image after permuting to channels-last.
    """
    flat = feats.permute(1, 2, 0).reshape(-1, feats.shape[0])
    projected = fitted_pca.transform(flat).reshape(height, width, 3)
    return torch.sigmoid(torch.from_numpy(projected).permute(2, 0, 1) * 2.0)


pca = PCA(n_components=3, whiten=True)
# Fit on left patches that survived foreground filtering: [N_selected, D].
fg_feats_left = feat0[:, selection].permute(1, 0)

# A whitened 3-component PCA needs more than 3 samples. With only 3 centred points, one
# component has ~zero variance and whitening would divide by ~0.
if len(fg_feats_left) > 3:
    pca.fit(fg_feats_left)

    # Left: background patches are zeroed out with the left foreground mask.
    pca_left = pca_rgb(feat0, h1, w1, pca) * mask1.float()
    # Right: shown unmasked, using the projection fitted on the left image.
    pca_right = pca_rgb(feat1, h2, w2, pca)

    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(pca_left.permute(1, 2, 0))
    ax[0].set_title("Dense Left")
    ax[1].imshow(pca_right.permute(1, 2, 0))
    ax[1].set_title("Dense Right")
    plt.show()
else:
    print(f"Only {len(fg_feats_left)} selected patches; need more than 3 to fit the PCA. "
          "Skipping the dense visualization.")

## 5. Sparse Correspondence

In [None]:
STRATIFY_MIN_DIST_PX = 100.0  # spacing (original-image pixels); passed squared to stratify_points
line_color_rng = np.random.default_rng(0)  # seeded so the match colours are reproducible

if len(locs_2d_left_fg) > 0:
    # Per-axis (row, col) factors from resized-image pixels to original-image pixels.
    # An h x w patch grid spans exactly (h * PATCH_SIZE, w * PATCH_SIZE) resized pixels,
    # assuming resize_to_patch_multiple rounds both sides to PATCH_SIZE multiples (TODO
    # confirm). Dividing by IMAGE_SIZE instead is off whenever IMAGE_SIZE is not itself a
    # multiple of PATCH_SIZE (768 vs 14 here). It also ignores aspect changes from rounding.
    scale_left = torch.tensor([image_left.height / (h1 * PATCH_SIZE),
                               image_left.width / (w1 * PATCH_SIZE)])
    scale_right = torch.tensor([image_right.height / (h2 * PATCH_SIZE),
                                image_right.width / (w2 * PATCH_SIZE)])

    # Thin out matches with the library helper, working in original-image pixels.
    ids_ex, ids_keep = stratify_points(locs_2d_left_fg * scale_left, threshold=STRATIFY_MIN_DIST_PX**2)

    pts_left = locs_2d_left_fg[ids_keep]  # (row, col), resized-image pixels
    pts_right = locs_2d_right_fg[ids_keep]

    print(f"Stratified: {len(pts_left)} points")

    pts_left_px = pts_left * scale_left  # (row, col), original-image pixels
    pts_right_px = pts_right * scale_right

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    ax1.imshow(image_left)
    ax1.axis('off')
    ax2.imshow(image_right)
    ax2.axis('off')

    for (r1, c1), (r2, c2) in zip(pts_left_px.tolist(), pts_right_px.tolist()):
        con = ConnectionPatch(
            xyA=(c1, r1),
            xyB=(c2, r2),
            coordsA='data', coordsB='data',
            axesA=ax1, axesB=ax2, color=line_color_rng.random(3),
        )
        # A patch spanning two axes belongs to the figure, so it is drawn over both axes.
        fig.add_artist(con)
    plt.show()
else:
    print("No foreground matches to display.")