# Segmentation Tracking with DINOv3

This notebook demonstrates how to use DINOv3 for video segmentation tracking with a non-parametric method similar to "Space-time correspondence as a contrastive random walk" (Jabri et al. 2020).

Given:
- RGB video frames
- Instance segmentation masks for the first frame

We will extract patch features from each frame and use patch similarity to propagate the ground-truth labels to all frames.

## Setup
Let's start by loading some prerequisites, setting up the environment and checking the DINOv3 repository location:

In [None]:
from collections import deque
import datetime
import functools
import io
import logging
import math
import os
from pathlib import Path
import tarfile
import time
import urllib
import urllib.request

import lovely_tensors
import matplotlib.pyplot as plt
import mediapy as mp
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms as TVT
import torchvision.transforms.functional as TVTF
from torch import Tensor, nn
from tqdm import tqdm

# Library Imports
from dinov3production.video.tracking import propagate, make_neighborhood_mask

DISPLAY_HEIGHT = 200

# Notebook-wide setup: lovely tensor reprs, autograd disabled (inference only), INFO logs.
lovely_tensors.monkey_patch()
torch.set_grad_enabled(False)
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

DINOV3_GITHUB_LOCATION = "facebookresearch/dinov3"

# A DINOV3_LOCATION environment variable (e.g. a local checkout) overrides the GitHub repo.
DINOV3_LOCATION = os.getenv("DINOV3_LOCATION", DINOV3_GITHUB_LOCATION)

print(f"DINOv3 location set to {DINOV3_LOCATION}")

## Model
We load the DINOv3 ViT-L model and print some of its attributes.

In [None]:
# Examples of available DINOv3 backbones (torch.hub entry-point names):
MODEL_DINOV3_VITS = "dinov3_vits16"
MODEL_DINOV3_VITSP = "dinov3_vits16plus"
MODEL_DINOV3_VITB = "dinov3_vitb16"
MODEL_DINOV3_VITL = "dinov3_vitl16"
MODEL_DINOV3_VITHP = "dinov3_vith16plus"
MODEL_DINOV3_VIT7B = "dinov3_vit7b16"

# We use DINOv3 ViT-L.
MODEL_NAME = MODEL_DINOV3_VITL

# Load from GitHub unless DINOV3_LOCATION points somewhere else (a local checkout), then
# move the model to the GPU in eval mode; nn.Module.to() and .eval() both return the module.
model = torch.hub.load(
    repo_or_dir=DINOV3_LOCATION,
    model=MODEL_NAME,
    source="github" if DINOV3_LOCATION == DINOV3_GITHUB_LOCATION else "local",
).to("cuda").eval()

patch_size, embed_dim = model.patch_size, model.embed_dim
print(f"Patch size: {patch_size}")
print(f"Embedding dimension: {embed_dim}")
print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 2**30:.1f} GB")

We want to process one image at a time and get L2-normalized patch features. Here is a wrapper that does just that.

In [None]:
@torch.compile(disable=True)
def forward(
    model: nn.Module,
    img: Tensor,  # [3, H, W] already normalized for the model
) -> Tensor:
    """Compute L2-normalized patch features for a single image.

    Args:
        model: Backbone exposing ``get_intermediate_layers``.
        img: Image tensor of shape [3, H, W], already normalized for the model.

    Returns:
        Patch features of shape [h, w, D]; every D-dimensional vector has unit L2 norm.
    """
    batch = img[None]  # [1, 3, H, W]
    patch_map = model.get_intermediate_layers(batch, n=1, reshape=True)[0]  # [1, D, h, w]
    channels_last = patch_map[0].permute(1, 2, 0)  # [h, w, D]
    return F.normalize(channels_last, p=2, dim=-1)

## Data
Here we load the video frames and the instance segmentation masks for the first frame.

In [None]:
VIDEO_FRAMES_URI = "https://dl.fbaipublicfiles.com/dinov3/notebooks/segmentation_tracking/video_frames.tar.gz"

def load_video_frames_from_remote_tar(tar_uri: str) -> list[Image.Image]:
    """Download a tar archive of video frames and return them as RGB images in frame order.

    Every regular ``.png`` / ``.jpg`` / ``.jpeg`` member is decoded. Its file stem must be an
    integer frame index, which is used to sort the frames.

    Args:
        tar_uri: Location of the (optionally compressed) tar archive, in any form accepted by
            ``urllib.request.urlopen``.

    Returns:
        The decoded frames sorted by frame index. If the archive cannot be downloaded, opened
        or decoded, a warning is printed and ten gray 1920x1440 placeholder frames are
        returned instead, so that the rest of the notebook can still run offline.
    """
    images = []
    indices = []
    try:
        with urllib.request.urlopen(tar_uri) as f:
            payload = io.BytesIO(f.read())
        # Close the archive deterministically; frames are fully decoded (convert) inside.
        with tarfile.open(fileobj=payload) as tar:
            for member in tar.getmembers():
                # Skip directories/links: extractfile() would return None for them.
                if not member.isfile():
                    continue
                if not member.name.lower().endswith((".png", ".jpg", ".jpeg")):
                    continue
                index_str, _ = os.path.splitext(os.path.basename(member.name))
                image_data = tar.extractfile(member)
                image = Image.open(image_data).convert("RGB")
                images.append(image)
                indices.append(int(index_str))
    except (OSError, EOFError, tarfile.TarError) as e:
        # Network, archive and image-decoding failures only; other errors (e.g. a non-numeric
        # frame name) are bugs in the data and should not be hidden behind dummy frames.
        print(f"Warning: Failed to load video frames: {e}. Generating dummy frames.")
        # Fallback for offline/test
        return [Image.new('RGB', (1920, 1440), color='gray') for _ in range(10)]

    order = np.argsort(indices)
    return [images[i] for i in order]

# Download and decode all video frames (sorted by frame index).
frames = load_video_frames_from_remote_tar(VIDEO_FRAMES_URI)
num_frames = len(frames)
print(f"Number of frames: {num_frames}")

if num_frames > 0:
    # PIL's Image.size is (width, height).
    original_width, original_height = frames[0].size
    print(f"Original size: width={original_width}, height={original_height}")

Let's show four sample frames from the video:

In [None]:
if num_frames > 0:
    # Four frame indices evenly spaced from the first to the last frame (both included).
    num_selected_frames = 4
    selected_frames = np.linspace(0, num_frames - 1, num_selected_frames, dtype=int)

    mp.show_images(
        [frames[int(i)] for i in selected_frames],
        titles=[f"Frame {i}" for i in selected_frames],
        height=DISPLAY_HEIGHT,
    )

This notebook assumes that the instance segmentation masks for the first frame are stored in a single .png file, where each pixel value is an instance ID and 0 denotes the background:

In [None]:
def mask_to_rgb(mask: np.ndarray | Tensor, num_masks: int) -> np.ndarray:
    """Colorize a label map for display.

    Label 0 is background and is painted black. The remaining labels get colors from a
    qualitative palette (tab10 or tab20) when there are at most 20 of them, and from the
    continuous gist_rainbow colormap otherwise.

    Args:
        mask: Label map as a NumPy array or a tensor (tensors are moved to CPU first).
        num_masks: Number of labels, background included.

    Returns:
        uint8 array with the shape of ``mask`` plus a trailing RGB axis of size 3.
    """
    labels = mask.cpu().numpy() if isinstance(mask, Tensor) else mask

    is_background = labels == 0
    # Shift so the first foreground label maps to the first palette color;
    # background pixels get an arbitrary color here and are blacked out below.
    fg_labels = labels - 1
    num_fg = num_masks - 1

    if num_fg > 20:
        colors = plt.get_cmap("gist_rainbow")(fg_labels / (num_fg - 1))
    elif num_fg > 10:
        colors = plt.get_cmap("tab20")(fg_labels)
    else:
        colors = plt.get_cmap("tab10")(fg_labels)

    rgb = (colors[..., :3] * 255).astype(np.uint8)
    rgb[is_background, :] = 0
    return rgb


def load_image_from_url(url: str) -> Image.Image:
    """Download and decode an image, falling back to a blank mask on failure.

    The response body is read completely and decoded eagerly, so the returned image does
    not depend on the (closed) network connection.

    Args:
        url: Any URL accepted by ``urllib.request.urlopen``.

    Returns:
        The decoded image or, if downloading/decoding fails, an all-zero single-channel
        ("L") 1920x1440 image (a warning is printed in that case).
    """
    try:
        with urllib.request.urlopen(url) as f:
            image = Image.open(io.BytesIO(f.read()))
            image.load()  # decode now instead of lazily on first pixel access
        return image
    except (OSError, ValueError) as e:
        # URLError/HTTPError and PIL's UnidentifiedImageError are OSError subclasses;
        # ValueError covers malformed URLs. Unlike a bare except, Ctrl-C still interrupts.
        print(f"Warning: Failed to load image from {url}: {e}. Using an empty mask.")
        return Image.new('L', (1920, 1440), color=0)


# Instance mask of the first frame: every pixel stores an instance id, 0 is background.
first_mask_np = np.array(
    load_image_from_url(
        "https://dl.fbaipublicfiles.com/dinov3/notebooks/segmentation_tracking/first_video_frame_mask.png"
    )
)
if first_mask_np.ndim != 2:
    raise ValueError(f"Expected a single-channel instance mask, got shape {first_mask_np.shape}")

if first_mask_np.max() == 0 and num_frames > 0:
    # No instances (e.g. the download fell back to a blank image): paint two dummy squares.
    first_mask_np[100:500, 100:500] = 1
    first_mask_np[600:900, 600:900] = 2

mask_height, mask_width = first_mask_np.shape
print(f"Mask size: {[mask_height, mask_width]}")

# Convert to a Python int before adding 1: for a uint8 mask whose largest id is 255,
# NumPy 2 scalar arithmetic would otherwise wrap around to 0.
num_masks = int(first_mask_np.max()) + 1
print(f"Number of masks: {num_masks}")

if num_frames > 0:
    mp.show_images(
        [frames[0], mask_to_rgb(first_mask_np, num_masks)],
        titles=["Frame", "Mask"],
        height=DISPLAY_HEIGHT,
    )

## Transforms
Input frames need to be resized to match the desired forward resolution and the model patch size.

In [None]:
class ResizeToMultiple(nn.Module):
    """Resize an image to a target short side, with both sides multiples of ``multiple``.

    The short side becomes ``short_side`` rounded up to a multiple of ``multiple``. The long
    side is scaled to preserve the aspect ratio and then also rounded up. Resizing uses
    bicubic interpolation. Square images are handled like portrait ones.
    """

    def __init__(self, short_side: int, multiple: int):
        super().__init__()
        self.short_side = short_side
        self.multiple = multiple

    def _round_up(self, side: float) -> int:
        """Return the smallest multiple of ``self.multiple`` that is >= ``side``."""
        return self.multiple * math.ceil(side / self.multiple)

    def forward(self, img):
        width, height = TVTF.get_image_size(img)
        landscape = width > height
        short, long_side = (height, width) if landscape else (width, height)

        new_short = self._round_up(self.short_side)
        new_long = self._round_up(long_side * new_short / short)

        # TVTF.resize expects [height, width].
        size = [new_short, new_long] if landscape else [new_long, new_short]
        return TVTF.resize(img, size, interpolation=TVT.InterpolationMode.BICUBIC)


# Target short side of the model input (rounded up to a multiple of the patch size).
SHORT_SIDE = 960

# Resize -> float tensor in [0, 1] -> per-channel mean/std normalization
# (the usual ImageNet statistics).
transform = TVT.Compose(
    [
        ResizeToMultiple(short_side=SHORT_SIDE, multiple=patch_size),
        TVT.ToTensor(),
        TVT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
if num_frames > 0:
    first_frame = transform(frames[0]).to("cuda")
    print(f"First frame: {first_frame}")

    # Size of the patch grid: the resize guarantees both sides are multiples of patch_size.
    _, frame_height, frame_width = first_frame.shape
    feats_height, feats_width = frame_height // patch_size, frame_width // patch_size

Label propagation happens at the output resolution of the model, so we downsample the ground-truth masks of the first frame and turn them into a one-hot probability map.

In [None]:
# Downsample the ground-truth label map to the patch grid [feats_height, feats_width].
# "nearest-exact" copies existing labels rather than blending ids. The mask is cast to float
# for F.interpolate and back to long, which F.one_hot requires.
first_mask = torch.from_numpy(first_mask_np).to("cuda", dtype=torch.long)
first_mask = F.interpolate(
    first_mask[None, None, :, :].float(),
    (feats_height, feats_width),
    mode="nearest-exact",
)[0, 0].long()

# One-hot label probabilities: [feats_height, feats_width, num_masks].
first_probs = F.one_hot(first_mask, num_masks).float()
print(f"First mask shape: {first_mask.shape}")
print(f"First probs shape: {first_probs.shape}")

## How it works
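Label propagation takes as input the features of the current frame, together with the features and label probabilities of the context frames. It computes the similarity between current and context patches within a spatial neighborhood. It then combines the probabilities of the most similar context patches (controlled by `topk` and `temperature`) into label probabilities for the current frame.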

In [None]:
# Using library implementations for optimized propagation
from dinov3production.video.tracking import propagate, make_neighborhood_mask

Visualization of the neighborhood mask for two sample patch positions:

In [None]:
# Neighborhood mask on the patch grid. Indexing with [i, j] yields a 2D mask for patch (i, j)
# (presumably the context positions that patch may match -- TODO confirm in
# dinov3production.video.tracking). It is shown here for two sample patches.
neighborhood_mask = make_neighborhood_mask(feats_height, feats_width, size=12, shape="circle")

mp.show_images(
    {f"{(i, j)}": neighborhood_mask[i, j].cpu().numpy() for i, j in [[3, 14], [20, 25]]},
    height=DISPLAY_HEIGHT,
)

To understand how it works, let's do it for one frame only. The "context" contains only the first frame and the "current frame" is the second one.

In [None]:
if num_frames > 1:
    # Hint to torch.compile that the spatial dims (H, W) of the input may vary between calls.
    torch._dynamo.maybe_mark_dynamic(first_frame, (1, 2))
    first_feats = forward(model, first_frame)  # [h, w, D]

    # The "current" frame is the second frame of the video.
    frame_idx = 1
    current_frame_pil = frames[frame_idx]
    current_frame = transform(current_frame_pil).to("cuda")
    torch._dynamo.maybe_mark_dynamic(current_frame, (1, 2))
    current_feats = forward(model, current_frame)

    # The context holds only the first frame: unsqueeze(0) adds a context-length dim of size 1.
    current_probs = propagate(
        current_feats,
        context_features=first_feats.unsqueeze(0),
        context_probs=first_probs.unsqueeze(0),
        neighborhood_mask=neighborhood_mask,
        topk=5,
        temperature=0.2,
    )
    print(f"Current probs shape: {current_probs.shape}")

Then, we upsample the predicted probabilities to the mask resolution, min-max normalize each channel, and take the argmax to obtain the predicted mask.

In [None]:
def postprocess_probs(probs: Tensor) -> Tensor:
    """Min-max normalize every channel of a batch of score maps to [0, 1].

    Args:
        probs: Tensor of shape [B, C, H, W].

    Returns:
        Tensor of the same shape. For each (batch, channel) pair, the spatial minimum maps
        to 0 and the maximum to 1. Spatially constant channels (0/0) become all zeros.
    """
    lowest = probs.amin(dim=(2, 3), keepdim=True)
    highest = probs.amax(dim=(2, 3), keepdim=True)
    scaled = (probs - lowest) / (highest - lowest)
    # Constant channels produce 0/0 = NaN; replace those with 0.
    return torch.nan_to_num(scaled, nan=0)

if num_frames > 1:
    # Channels-last -> [1, num_masks, h, w] for F.interpolate (assumes current_probs is
    # [h, w, num_masks] like first_probs -- TODO confirm), then upsample to mask resolution.
    p = current_probs.movedim(-1, -3).unsqueeze(0)
    p = F.interpolate(p, size=(mask_height, mask_width), mode="nearest")
    # Per-channel min-max normalization, then drop the batch dim.
    p = postprocess_probs(p).squeeze(0)
    # Predicted label per pixel = channel with the highest normalized score.
    current_pred_np = p.argmax(0).cpu().numpy()
    current_probs_np = p.cpu().numpy()

    # 2x2 grid: both frames on top, first-frame ground truth and second-frame prediction below.
    mp.show_images(
        [
            frames[0],
            current_frame_pil,
            mask_to_rgb(first_mask_np, num_masks),
            mask_to_rgb(current_pred_np, num_masks),
        ],
        titles=["First frame", "Second frame", "", ""],
        columns=2,
        height=DISPLAY_HEIGHT,
    )

## Process Video
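Process all frames, using as context the first frame plus a queue with the predictions of the most recent frames (at most `MAX_CONTEXT_LENGTH`).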

In [None]:
# Propagation hyperparameters.
MAX_CONTEXT_LENGTH = 7  # max number of recent predicted frames kept as context, besides the first frame
NEIGHBORHOOD_SIZE = 12
NEIGHBORHOOD_SHAPE = "circle"
TOPK = 5
TEMPERATURE = 0.2

# Per-frame outputs at mask resolution; frame 0 holds the ground truth.
mask_predictions = torch.zeros([num_frames, mask_height, mask_width], dtype=torch.uint8)
mask_predictions[0, :, :] = torch.from_numpy(first_mask_np)

mask_probabilities = torch.zeros([num_frames, num_masks, mask_height, mask_width])
mask_probabilities[0, :, :, :] = F.one_hot(torch.from_numpy(first_mask_np).long(), num_masks).movedim(-1, -3)

# Bounded FIFO context queues: once MAX_CONTEXT_LENGTH items are stored, append() drops the
# oldest entry automatically (O(1)), replacing the manual list.pop(0) bookkeeping.
features_queue: deque[Tensor] = deque(maxlen=MAX_CONTEXT_LENGTH)
probs_queue: deque[Tensor] = deque(maxlen=MAX_CONTEXT_LENGTH)

neighborhood_mask = make_neighborhood_mask(
    feats_height,
    feats_width,
    size=NEIGHBORHOOD_SIZE,
    shape=NEIGHBORHOOD_SHAPE,
)

if num_frames > 1:
    start = time.perf_counter()
    for frame_idx in tqdm(range(1, num_frames), desc="Processing"):
        current_frame_pil = frames[frame_idx]
        current_frame = transform(current_frame_pil).to("cuda")
        torch._dynamo.maybe_mark_dynamic(current_frame, (1, 2))
        current_feats = forward(model, current_frame)

        # Context = first (ground-truth) frame + up to MAX_CONTEXT_LENGTH most recent frames.
        context_feats = torch.stack([first_feats, *features_queue], dim=0)
        context_probs = torch.stack([first_probs, *probs_queue], dim=0)
        # The context length (dim 0) grows over the first frames, and dim 3 of the probs is the
        # number of masks; hint to torch.compile that these sizes may vary.
        torch._dynamo.maybe_mark_dynamic(context_feats, 0)
        torch._dynamo.maybe_mark_dynamic(context_probs, (0, 3))

        # Pass by keyword so the call does not depend on propagate's positional parameter order.
        current_probs = propagate(
            current_feats,
            context_features=context_feats,
            context_probs=context_probs,
            neighborhood_mask=neighborhood_mask,
            topk=TOPK,
            temperature=TEMPERATURE,
        )

        # Patch-resolution features/probabilities become context for the following frames.
        features_queue.append(current_feats)
        probs_queue.append(current_probs)

        # Upsample to mask resolution: [h, w, num_masks] -> [1, num_masks, H_mask, W_mask].
        current_probs = F.interpolate(
            current_probs.movedim(-1, -3)[None, :, :, :],
            size=(mask_height, mask_width),
            mode="nearest",
        )
        current_probs = postprocess_probs(current_probs)
        current_probs = current_probs.squeeze(0)
        mask_probabilities[frame_idx, :, :, :] = current_probs
        pred = torch.argmax(current_probs, dim=0).to(dtype=torch.uint8)
        mask_predictions[frame_idx, :, :] = pred

    # Wait for any queued GPU work before reading the clock.
    torch.cuda.synchronize()
    end = time.perf_counter()
    print(f"Processing time: {datetime.timedelta(seconds=round(end - start))}")

Let's visualize a few frames and a video of the result.

In [None]:
if num_frames > 0:
    # Top row: the sample frames picked earlier; bottom row: their predicted masks.
    mp.show_images(
        [frames[i].convert("RGB") for i in selected_frames]
        + [mask_to_rgb(mask_predictions[i], num_masks) for i in selected_frames],
        titles=[f"Frame {i}" for i in selected_frames] + [""] * len(selected_frames),
        columns=len(selected_frames),
        height=DISPLAY_HEIGHT,
    )

    # Input video next to the colorized predictions for every frame, played at 24 fps.
    mp.show_videos(
        {
            "Input": [np.array(frame) for frame in frames],
            "Pred": mask_to_rgb(mask_predictions, num_masks),
        },
        height=DISPLAY_HEIGHT,
        fps=24,
    )

## Conclusion
This notebook showed how to use DINOv3 for video segmentation tracking. It should be fairly straightforward to run it on your own video and masks. The notebook hyperparameters can also be adjusted to see their effect on the results.