# Training a Foreground Segmentation Tool with DINOv3

In this tutorial, we will train a linear foreground segmentation model using DINOv3 features. We use real sample images.

In [None]:
import io
import os
import pickle
import tarfile
import urllib

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.linear_model import LogisticRegression

import torch
import torchvision.transforms.functional as TF
from tqdm import tqdm

# Library Imports
from dinov3production import create_model
from dinov3production.data.transforms import resize_to_patch_multiple, quantize_mask

# Geometry shared by every cell below.
PATCH_SIZE = 14  # must agree with the patch size of the backbone chosen below
IMAGE_SIZE = 768  # target size handed to resize_to_patch_multiple

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Backbone: ViT-S/14 keeps memory usage low, which is friendlier to Colab.
# Swap in 'dinov3_vitl14' (Large) when enough GPU RAM is available.
model = create_model('dinov3_vits14', pretrained=True)
model.to(device)
model.eval()  # the backbone is only used for feature extraction, never trained

## 1. Data Loading
We load real images from the web. Training needs ground-truth masks, which a real application would provide as hand-labeled black/white images. For this tutorial we approximate the foreground of each image with a centered box (a "weak label") and visualize one image next to its mask.

In [None]:
def load_image_from_url(url: str) -> Image.Image:
    """Download the image at *url* and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        Exception: Network and decoding errors are propagated unchanged so the
            caller can decide how to handle a failed download.
    """
    # A bare `import urllib` does not guarantee that the `request` submodule
    # is loaded, so import it explicitly where it is used.
    import urllib.request

    # NOTE(review): Wikimedia's User-Agent policy asks clients to identify
    # themselves; generic library defaults may be refused. Send a descriptive
    # agent instead of urllib's default.
    request = urllib.request.Request(
        url,
        headers={"User-Agent": "dinov3-foreground-tutorial/1.0 (educational notebook)"},
    )
    with urllib.request.urlopen(request) as f:
        # convert() decodes the pixel data while the response is still open.
        return Image.open(f).convert("RGB")

# Small training set of real images.
DATA_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/600px-American_Eskimo_Dog.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/c/c3/Aurora_as_seen_from_Fairbanks_Alaska.jpg/600px-Aurora_as_seen_from_Fairbanks_Alaska.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/2/21/Mandel_zoom_00_mandelbrot_set.jpg/600px-Mandel_zoom_00_mandelbrot_set.jpg",
]

# images[k] and labels[k] always describe the same sample.
images = []
labels = []

print("Downloading real samples...")
for url in DATA_URLS:
    # Only the download is best-effort: one unreachable URL should not abort
    # the tutorial, but it must not leave the two lists out of step either.
    try:
        img = load_image_from_url(url)
    except Exception as e:
        print(f"Failed to download {url}: {e}")
        continue

    # Generate a 'weak label' mask for tutorial purposes.
    # In a real app you would have hand-labeled masks (black/white PNGs);
    # here a centered box covering the middle half of each axis stands in
    # for the foreground.
    w, h = img.size
    margin_w, margin_h = int(w * 0.25), int(h * 0.25)
    mask = Image.new('L', (w, h), 0)
    # paste() with an integer fills the box in one call. The box is
    # (left, upper, right, lower) with right/lower exclusive, i.e. exactly the
    # pixels x in [margin_w, w - margin_w) and y in [margin_h, h - margin_h).
    mask.paste(255, (margin_w, margin_h, w - margin_w, h - margin_h))

    images.append(img)
    labels.append(mask)

n_images = len(images)
print(f"Loaded {n_images} images.")

# Visualize one example next to its weak label.
if n_images > 0:
    plt.subplot(1, 2, 1)
    plt.imshow(images[0])
    plt.title("Image")
    plt.subplot(1, 2, 2)
    plt.imshow(labels[0])
    plt.title("Weak Label Mask")
    plt.show()

## 2. Feature Extraction & Label Building
Resize each image and its mask to a multiple of the patch size, quantize the mask to one label per patch, and extract the per-patch DINOv3 features.

In [None]:
# Accumulators filled by the extraction loop below, one entry per image:
# patch features, patch labels, and the index of the image each patch came from.
xs, ys, image_index = [], [], []

# Per-channel statistics passed to TF.normalize before the backbone runs.
IMAGENET_STD = (0.229, 0.224, 0.225)
IMAGENET_MEAN = (0.485, 0.456, 0.406)

def extract_dino_features(model, x, patch_size=None):
    """Return the backbone's patch tokens for *x* as a spatial feature map.

    Args:
        model: Object exposing ``forward_features(x)`` and, optionally, a
            ``num_register_tokens`` attribute (treated as 0 when absent).
            Assumes ``forward_features`` returns a ``[B, tokens, D]`` tensor
            laid out as ``[CLS, REG_1..REG_K, PATCHES...]`` -- TODO confirm
            against the model library.
        x: Image batch of shape ``[B, 3, H, W]``.
        patch_size: Side length of one patch in pixels. Defaults to the
            module-level ``PATCH_SIZE``.

    Returns:
        Tensor of shape ``[B, D, H // patch_size, W // patch_size]``.

    Raises:
        ValueError: If the number of patch tokens does not match the patch
            grid implied by ``H``, ``W`` and ``patch_size``.
    """
    if patch_size is None:
        patch_size = PATCH_SIZE

    B, _, H, W = x.shape

    # Mixed precision only applies to CUDA inputs. Naming the device type
    # explicitly avoids the deprecated torch.cuda.amp alias and the warning it
    # emits on CPU-only machines.
    on_cuda = x.is_cuda
    with torch.autocast(device_type="cuda" if on_cuda else "cpu", enabled=on_cuda):
        out = model.forward_features(x)

    # Drop the CLS token and any register tokens that precede the patches.
    n_reg = getattr(model, 'num_register_tokens', 0)
    patch_tokens = out[:, 1 + n_reg:]

    h = H // patch_size
    w = W // patch_size
    if patch_tokens.shape[1] != h * w:
        # Without this check a mismatched token count would either fail with
        # an opaque reshape error or be folded into the feature dimension.
        raise ValueError(
            f"Expected {h * w} patch tokens for a {h}x{w} grid, "
            f"got {patch_tokens.shape[1]}"
        )

    feats = patch_tokens.reshape(B, h, w, -1)
    return feats.permute(0, 3, 1, 2)  # [B, D, h, w]

with torch.inference_mode():
    for i in tqdm(range(n_images), desc="Processing images"):
        # Label: take the last channel of the mask, resize it, then quantize
        # it with the patch size so it lines up with the patch features.
        mask_i = labels[i].split()[-1]
        mask_i_resized = resize_to_patch_multiple(mask_i, PATCH_SIZE, IMAGE_SIZE)
        mask_i_quantized = quantize_mask(mask_i_resized, PATCH_SIZE)
        ys.append(mask_i_quantized.view(-1).cpu())

        # Image: resize, normalize with the ImageNet statistics, add a batch axis.
        image_i = images[i].convert('RGB')
        image_i_resized = resize_to_patch_multiple(image_i, PATCH_SIZE, IMAGE_SIZE)
        image_i_norm = TF.normalize(image_i_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD).unsqueeze(0).to(device)

        feats = extract_dino_features(model, image_i_norm)

        # [1, D, h, w] -> [h*w, D]. Only the batch axis is squeezed, so a grid
        # with a single row or column keeps its shape.
        # `dim` stays a module-level name because the test cell reads it.
        dim = feats.shape[1]
        xs.append(feats.squeeze(0).reshape(dim, -1).permute(1, 0).cpu())

        # Remember which image every patch came from (used for the CV folds).
        image_index.append(torch.full(ys[-1].shape, float(i)))

if len(xs) > 0:
    xs = torch.cat(xs)
    ys = torch.cat(ys)
    image_index = torch.cat(image_index)

    # Keep only patches whose label is (almost) purely background or purely
    # foreground; mixed patches on mask edges are ambiguous.
    idx = (ys < 0.01) | (ys > 0.99)
    xs = xs[idx]
    ys = ys[idx]
    image_index = image_index[idx]

    print("Design matrix:", xs.shape)
    print("Label matrix:", ys.shape)

## 3. Leave-One-Out Cross-Validation
Train logistic-regression classifiers with different values of the regularization parameter C on N-1 images and evaluate each on the held-out image.

In [None]:
# Leave-one-out needs at least one image to train on and one to validate on;
# with a single image the training fold would be empty and fit() would fail.
if len(xs) > 0 and n_images >= 2:
    cs = np.logspace(-7, 0, 8)  # candidate C values: 1e-7 ... 1e0
    scores = np.zeros((n_images, len(cs)))  # scores[i, j]: AP on image i with cs[j]

    for i in range(n_images):
        print(f'Validation using image_{i+1:02d}')

        # Hold out every patch of image i; train on the patches of the others.
        train_selection = image_index != float(i)
        fold_x = xs[train_selection].numpy()
        fold_y = (ys[train_selection] > 0).long().numpy()
        val_x = xs[~train_selection].numpy()
        val_y = (ys[~train_selection] > 0).long().numpy()

        for j, c in enumerate(cs):
            clf = LogisticRegression(random_state=0, C=c, max_iter=1000).fit(fold_x, fold_y)
            output = clf.predict_proba(val_x)
            # Column 1 holds the probability of the positive (foreground) class.
            scores[i, j] = average_precision_score(val_y, output[:, 1])

    # Plot the AP averaged over folds to find the best C.
    plt.figure(figsize=(5, 3))
    plt.plot(scores.mean(axis=0))
    plt.xticks(np.arange(len(cs)), [f"{c:.0e}" for c in cs])
    plt.xlabel('C')
    plt.ylabel('Average AP')
    plt.grid()
    plt.title("Cross-Validation Results")
    plt.show()
elif len(xs) > 0:
    print("Skipping cross-validation: leave-one-out needs at least 2 images.")

## 4. Final Training & Saving
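Retrain on all data with the chosen C and save the classifier. The code below uses C=0.1; the optimum is usually 0.1 or 1.0, so pick the value with the highest average AP in the cross-validation plot above.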

In [None]:
if len(xs) > 0:
    # C is fixed at 0.1, following the tutorial's suggestion.
    print("Retraining with C=0.1 on full dataset...")
    train_x = xs.numpy()
    train_y = (ys > 0).long().numpy()  # binarize the patch labels
    final_clf = LogisticRegression(random_state=0, C=0.1, max_iter=5000)
    final_clf.fit(train_x, train_y)

    # Persist the fitted classifier next to the notebook.
    with open("fg_classifier.pkl", "wb") as f:
        pickle.dump(final_clf, f)
    print("Saved fg_classifier.pkl")

## 5. Test Inference with Median Filter

In [None]:
# Image used to try out the trained classifier; it is not one of DATA_URLS.
TEST_IMAGE_URI = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d7/Green_Sea_Turtle_grazing_seagrass.jpg/640px-Green_Sea_Turtle_grazing_seagrass.jpg"

def load_image_from_url(url: str) -> Image.Image:
    """Download the image at *url* as RGB, falling back to a placeholder.

    This definition replaces the loader defined earlier in the notebook for
    every cell executed after it: instead of raising, a failed download
    yields a solid blue 500x500 image so the inference demo can still run.

    Args:
        url: Address of the image to fetch.

    Returns:
        The decoded RGB image, or the blue placeholder if anything went wrong.
    """
    # A bare `import urllib` does not guarantee that the `request` submodule
    # is loaded, so import it explicitly where it is used.
    import urllib.request

    # NOTE(review): Wikimedia's User-Agent policy asks clients to identify
    # themselves; generic library defaults may be refused.
    request = urllib.request.Request(
        url,
        headers={"User-Agent": "dinov3-foreground-tutorial/1.0 (educational notebook)"},
    )
    try:
        with urllib.request.urlopen(request) as f:
            return Image.open(f).convert("RGB")
    except Exception as e:
        # Stay best-effort, but say so: a silent placeholder would make the
        # plots below look like a model failure. Unlike a bare `except:`, this
        # lets KeyboardInterrupt and SystemExit through.
        print(f"Failed to download {url}: {e}. Using a blank placeholder image.")
        return Image.new('RGB', (500, 500), color='blue')

test_img = load_image_from_url(TEST_IMAGE_URI)
test_img_resized = resize_to_patch_multiple(test_img, PATCH_SIZE, IMAGE_SIZE)
test_norm = TF.normalize(test_img_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD).unsqueeze(0).to(device)

with torch.inference_mode():
    # Feature extraction for the test image.
    feats = extract_dino_features(model, test_norm)

    # Read the feature width and patch grid from the feature map itself
    # instead of relying on `dim` left over from the training loop, which is
    # undefined when no training image was loaded.
    _, feat_dim, h, w = feats.shape
    x_test = feats.squeeze(0).reshape(feat_dim, -1).permute(1, 0).cpu().numpy()

# Only predict when the training cell actually produced a classifier.
if 'final_clf' in locals():
    # Column 1 is the foreground probability; fold it back onto the patch grid.
    probs = final_clf.predict_proba(x_test)[:, 1].reshape(h, w)
    # A 3x3 median filter removes isolated patch-level outliers.
    probs_mf = signal.medfilt2d(probs, kernel_size=3)

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(test_img)
    plt.title("Input")
    plt.subplot(1, 3, 2)
    plt.imshow(probs)
    plt.title("Raw Probs")
    plt.subplot(1, 3, 3)
    plt.imshow(probs_mf)
    plt.title("+ Median Filter")
    plt.show()