# DINO-Text Inference

Open-vocabulary classification using DINOv3 features and text encoders.
Demonstrates zero-shot classification, patch alignment, and ImageNet evaluation.

In [None]:
import os
import sys
import urllib
import urllib.request

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Add path if running locally from repo root
sys.path.append("../")

from dinov3production.hub.dinotxt import dinov3_vitl16_dinotxt_tet1280d20h24l
from dinov3production.data.transforms import make_classification_eval_transform

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

## 1. Load Model
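Load the DINOv3 text-aligned model (ViT-L/16 image backbone) together with its tokenizer, then move the model to the selected device and switch it to eval mode.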

In [None]:
# Build the ViT-L/16 DINOv3 text-aligned model and its tokenizer via the
# project hub entry point.
model, tokenizer = dinov3_vitl16_dinotxt_tet1280d20h24l()
# The return value is discarded: this relies on `.to()` moving the model in
# place, which holds for torch.nn.Module (presumably what the factory returns
# — confirm).
model.to(device)
# Put the model in inference mode (for an nn.Module: dropout off, norm layers
# use stored statistics).
model.eval()

## 2. Load Sample Image
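Load a sample image from a URL. If the download fails, a solid red 224×224 placeholder is used instead so that the remaining cells still run.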

In [None]:
def load_image_from_url(url: str) -> Image.Image:
    """Download an image and return it decoded as an RGB PIL image.

    Args:
        url: Location of the image; anything ``urllib.request.urlopen`` accepts.

    Returns:
        The fully decoded image, converted to RGB mode.

    Raises:
        urllib.error.URLError: If the URL cannot be opened.
        PIL.UnidentifiedImageError: If the payload is not a recognized image.
    """
    # A bare `import urllib` does not guarantee the `request` submodule is
    # loaded, so import it explicitly here.
    import urllib.request

    with urllib.request.urlopen(url) as response:
        # Image.open() is lazy; convert() forces the full decode while the
        # response stream is still open.
        return Image.open(response).convert("RGB")

EXAMPLE_IMAGE_URL = "https://dl.fbaipublicfiles.com/dinov2/images/example.jpg"
try:
    img_pil = load_image_from_url(EXAMPLE_IMAGE_URL)
except Exception as e:
    # Deliberate best-effort fallback so the notebook still runs offline, but
    # report why the download failed instead of silently discarding the error.
    print("Failed to load example image, falling back to dummy image.")
    print(f"  reason: {type(e).__name__}: {e}")
    img_pil = Image.new('RGB', (224, 224), color='red')

# display(img_pil) # Uncomment in Jupyter

## 3. Zero-Shot Classification
Compute similarity between image global features and text descriptions.

In [None]:
# Preprocess the single image and give it a leading batch dimension of 1.
image_preprocess = make_classification_eval_transform()
image_tensor = image_preprocess(img_pil).unsqueeze(0).to(device)

# One text prompt per candidate; `class_names` holds short labels in the same order.
texts = [
    "photo of dogs",
    "photo of a chair",
    "photo of a bowl",
    "photo of a tupperware",
]
class_names = ["dog", "chair", "bowl", "tupperware"]
tokenized_texts_tensor = tokenizer.tokenize(texts).to(device)

with torch.autocast(device_type=device, dtype=torch.float), torch.no_grad():
    image_features = model.encode_image(image_tensor)
    text_features = model.encode_text(tokenized_texts_tensor)

# L2-normalise along the last axis so the products below are cosine similarities.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# Rows index texts, columns index images.
text_np = text_features.cpu().float().numpy()
image_np = image_features.cpu().float().numpy()
similarity = text_np @ image_np.T
print("Similarity scores:", similarity.flatten())

# argmax over the flattened matrix; with a single image this is the text index.
best_idx = similarity.argmax()
print(f"Best match: {texts[best_idx]}")

## 4. Patch Embeddings & Alignment
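Visualize the spatial alignment between text queries and image patches. Each output pixel is assigned the text prompt whose patch-aligned embedding is most similar to the upsampled patch feature at that location.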

In [None]:
import math

import matplotlib.pyplot as plt

# (height, width) of the rendered alignment map. This is a fixed display size,
# independent of the input image's aspect ratio.
ALIGNMENT_MAP_SIZE = (480, 640)

with torch.autocast(device_type=device, dtype=torch.float):
    with torch.no_grad():
        image_class_tokens, image_patch_tokens, backbone_patch_tokens = model.encode_image_with_patch_tokens(image_tensor)
        # Part of text features that is aligned to patch features (here: index 1024 onwards;
        # the split point is model-specific)
        full_text_feats = model.encode_text(tokenized_texts_tensor)
        text_features_aligned_to_patch = full_text_feats[:, 1024:]

# Assumes image_patch_tokens is (batch, num_patches, dim) with no extra
# class/register tokens, laid out row-major over a square grid. TODO: confirm.
B, P, D = image_patch_tokens.shape
# isqrt avoids float rounding, and the check turns a non-square token count into
# a clear error instead of an opaque unflatten() shape failure.
H = W = math.isqrt(P)
if H * W != P:
    raise ValueError(
        f"Expected a square grid of patch tokens, got {P} tokens "
        "(extra class/register tokens or a non-square input?)"
    )

x = image_patch_tokens.movedim(2, 1).unflatten(2, (H, W)).float()  # [B, D, H, W]
# Upsample patch features to the display size, then L2-normalise each pixel's
# feature vector so the einsum below yields cosine similarities.
x = F.interpolate(x, size=ALIGNMENT_MAP_SIZE, mode="bicubic", align_corners=False)
x = F.normalize(x, p=2, dim=1)

y = F.normalize(text_features_aligned_to_patch.float(), p=2, dim=1)  # [C, D]

# [B, C, H_out, W_out]: similarity of every pixel to every text prompt.
per_patch_similarity_to_text = torch.einsum("bdhw,cd->bchw", x, y)

# Argmax per pixel to see which text matches best where
pred_idx = per_patch_similarity_to_text.argmax(1).squeeze(0)

plt.imshow(pred_idx.cpu().numpy())
plt.title("Per-pixel Text Alignment")
plt.colorbar()
plt.show()

## 5. ImageNet1k Zero-Shot Evaluation
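Full zero-shot evaluation loop on the ImageNet-1k validation set, which must be in ImageFolder layout with one sub-folder per class. Set `imagenet_val_root_dir` below and provide the complete, correctly ordered list of class names. The cell is skipped if the directory does not exist.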

In [None]:
# ImageNet class names used to build text prompts. Entry i must describe the
# class that ImageFolder labels i (class folders sorted by name), because
# column i of the zero-shot weights is built from entry i.
# NOTE: truncated for brevity; a real ImageNet-1k run needs all 1000 names.
imagenet_clip_class_names = [
    "tench",
    "goldfish",
    "great white shark",
    "tiger shark",
    "hammerhead shark",
    "electric ray",
    "stingray",
    "rooster",
    "hen",
    "ostrich",
    # ... (remaining classes omitted here)
    "goldfinch",
    "house finch",
    "junco",
    "indigo bunting",
]

# Prompt templates (a subset; the name indicates the OpenAI CLIP ImageNet
# prompt set). Each "{}" is replaced by a class name.
_OPENAI_IMAGENET_TEMPLATE_FORMATS = (
    "a bad photo of a {}.",
    "a photo of many {}.",
    "a sculpture of a {}.",
    "a photo of the hard to see {}.",
    "a low resolution photo of the {}.",
    "a rendering of a {}.",
    "graffiti of a {}.",
    "a bad photo of the {}.",
    "a cropped photo of the {}.",
    "a tattoo of a {}.",
    "the embroidered {}.",
    "a photo of a hard to see {}.",
    "a bright photo of a {}.",
    "a photo of a clean {}.",
    "a photo of a dirty {}.",
    "a dark photo of the {}.",
    "a drawing of a {}.",
    "a photo of my {}.",
    "the plastic {}.",
    # ... and more
)

# Callables mapping a class name to a prompt string: template(classname) -> str.
openai_imagenet_templates = tuple(
    fmt.format for fmt in _OPENAI_IMAGENET_TEMPLATE_FORMATS
)

from torchvision.datasets import ImageFolder

def zeroshot_classifier(classnames, templates, tokenizer):
    """Build one text-embedding classifier column per class.

    Each class name is expanded through every template and encoded with the
    module-level ``model``. The per-prompt embeddings are L2-normalised and
    averaged, and the mean is re-normalised.

    Args:
        classnames: Class names; their order defines the column order.
        templates: Callables mapping a class name to a prompt string.
        tokenizer: Object whose ``tokenize(list_of_str)`` returns a tensor.

    Returns:
        A tensor on ``device`` with one column per class, i.e.
        (embed_dim, num_classes) assuming ``encode_text`` returns
        (num_prompts, embed_dim). TODO: confirm.
    """
    def _class_centroid(name):
        prompts = [make_prompt(name) for make_prompt in templates]
        tokens = tokenizer.tokenize(prompts).to(device)
        embeddings = model.encode_text(tokens)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        centroid = embeddings.mean(dim=0)
        return centroid / centroid.norm()

    with torch.no_grad():
        columns = [_class_centroid(name) for name in classnames]
        return torch.stack(columns, dim=1).to(device)

def accuracy(output, target, topk=(1,)):
    """Count samples whose target class is among the top-k scored classes.

    Args:
        output: Scores of shape (batch, num_classes).
        target: Integer class labels of shape (batch,).
        topk: The k values to evaluate; ``max(topk)`` must not exceed num_classes.

    Returns:
        A list with one 1-element tensor per k in ``topk``, holding the number
        (not the fraction) of correctly classified samples in the batch.
    """
    largest_k = max(topk)
    # Shape (largest_k, batch): row j holds each sample's j-th best class index.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True).indices.t()
    hits = ranked == target.view(1, -1).expand_as(ranked)
    counts = []
    for k in topk:
        # Top-k indices are distinct, so each sample contributes at most one hit.
        counts.append(hits[:k].reshape(-1).sum(0, keepdim=True))
    return counts

# Please update the following directory to the root of ImageNet1k val dataset.
imagenet_val_root_dir = "./imagenet_val_dummy" # Placeholder path

if os.path.exists(imagenet_val_root_dir):
    print("Starting ImageNet Evaluation...")

    # Build the dataset first so the class-count check fails fast, before the
    # comparatively expensive encoding of every class prompt.
    val_dataset = ImageFolder(imagenet_val_root_dir, image_preprocess)
    # ImageFolder labels class folders 0..C-1 in sorted order, and column i of
    # zeroshot_weights is built from imagenet_clip_class_names[i]. A truncated
    # or mismatched name list would silently produce meaningless accuracy.
    if len(val_dataset.classes) != len(imagenet_clip_class_names):
        raise ValueError(
            f"Dataset has {len(val_dataset.classes)} classes but "
            f"{len(imagenet_clip_class_names)} class names were provided; "
            "imagenet_clip_class_names must list every class in folder order."
        )

    zeroshot_weights = zeroshot_classifier(imagenet_clip_class_names, openai_imagenet_templates, tokenizer)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=64,
        shuffle=False,
        num_workers=4,
        # Pinned host memory only helps host-to-GPU copies.
        pin_memory=(device == "cuda"),
    )

    top1, top5, n = 0., 0., 0.
    for images, targets in val_loader:
        with torch.autocast(device_type=device, dtype=torch.float):
            with torch.no_grad():
                image_features = model.encode_image(images.to(device))
                image_features /= image_features.norm(dim=-1, keepdim=True)
                # The constant factor 100 does not change top-k rankings.
                logits = 100. * image_features @ zeroshot_weights
                acc1, acc5 = accuracy(logits, targets.to(device), topk=(1, 5))
                top1 += acc1
                top5 += acc5
                n += len(images)

    top1 = (top1.item() / n) * 100
    top5 = (top5.item() / n) * 100
    print(f"Top-1 accuracy: {top1}")
    print(f"Top-5 accuracy: {top5}")
else:
    print("ImageNet validation directory not found. Skipping full evaluation.")