# DINOv3 Quantization Tutorial

This tutorial demonstrates how to optimize DINOv3 models using Int4 quantization with `torchao`.

We will cover:
1. **Model Loading**: Loading a standard DINOv3 ViT-B model.
2. **Quantization**: Compressing the model weights to 4-bit integers.
3. **Benchmarking**: Comparing memory usage and latency between FP32 and Int4.
4. **Visualization**: Verifying that the attention maps (and thus semantic understanding) are preserved.

In [None]:
import torch
import time
import numpy as np
import matplotlib.pyplot as plt
from dinov3production import create_model
from dinov3production.quantize.quantizer import Quantizer
import dinov3production.visualization.image as viz

# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

## 1. Load Pretrained Model

In [None]:
model_name = 'dinov3_vitb14'
# NOTE(review): the "b14" suffix suggests a patch size of 14 (a 16x16 patch
# grid at 224px), while the attention cell below reshapes to 14x14 (patch 16).
# Confirm which checkpoint create_model actually returns for this name.
model = create_model(model_name, pretrained=True)
# Move to the selected device and switch to inference behaviour (eval mode).
model = model.to(device).eval()
print(f"Loaded {model_name}")

## 2. Quantization (Int4)
Use the `Quantizer` utility to apply Int4 weight-only quantization.

In [None]:
import copy

# Quantize a deep copy so that `model` stays an untouched FP32 baseline for the
# size/latency benchmark and the attention comparison.
# NOTE(review): torchao's `quantize_` rewrites a module in place. If
# Quantizer.to_int4 does the same, quantizing `model` directly would silently
# turn the "FP32" baseline into Int4 as well. The copy costs one extra set of
# FP32 weights in memory.
quantizer = Quantizer(copy.deepcopy(model))
q_model = quantizer.to_int4()
print("Model quantized to Int4.")

## 3. Benchmarking
Let's measure the improvements. We expect lower memory usage and, on hardware with efficient low-bit kernels, potentially faster inference.

In [None]:
def benchmark(model, input_tensor, n_warmup=10, n_runs=100):
    """Return the mean forward-pass latency of ``model`` in milliseconds.

    Runs ``n_warmup`` untimed calls, then times ``n_runs`` calls with
    ``time.perf_counter``. All calls run under ``torch.inference_mode`` so
    autograd bookkeeping is not part of the measurement. When the input lives
    on a CUDA device, the device is synchronized before and after the timed
    loop.

    Args:
        model: Callable invoked as ``model(input_tensor)``, typically an
            ``nn.Module``.
        input_tensor: Input passed unchanged to every call.
        n_warmup: Number of untimed warm-up calls.
        n_runs: Number of timed calls; must be positive.

    Returns:
        Average wall-clock time per timed call, in milliseconds.

    Raises:
        ValueError: If ``n_runs`` is not positive.
    """
    if n_runs <= 0:
        raise ValueError(f"n_runs must be positive, got {n_runs}")

    # torch.cuda.synchronize() raises when CUDA is unavailable, so only
    # synchronize when the input actually lives on a CUDA device.
    on_cuda = getattr(input_tensor, "is_cuda", False)

    def _sync():
        # CUDA kernels launch asynchronously; wait for them so the timer
        # measures finished work rather than just launch overhead.
        if on_cuda:
            torch.cuda.synchronize(input_tensor.device)

    # inference_mode skips autograd graph construction, which would otherwise
    # inflate both latency and memory for modules with trainable parameters.
    with torch.inference_mode():
        for _ in range(n_warmup):
            model(input_tensor)

        _sync()
        start_time = time.perf_counter()
        for _ in range(n_runs):
            model(input_tensor)
        _sync()
        end_time = time.perf_counter()

    return (end_time - start_time) * 1000 / n_runs

def get_model_size_mb(model):
    """Return the combined storage of ``model``'s parameters and buffers in MiB.

    Each tensor contributes ``nelement() * element_size()`` bytes; the byte
    total is divided by ``1024**2``. For tensor subclasses (e.g. quantized
    weights) this reports whatever ``element_size()`` returns for the wrapper,
    which may not match the packed storage.
    """
    total_bytes = sum(
        tensor.nelement() * tensor.element_size()
        for tensor in (*model.parameters(), *model.buffers())
    )
    return total_bytes / 1024**2

# Synthetic input: one 3x224x224 tensor of standard-normal noise.
dummy_input = torch.randn(1, 3, 224, 224).to(device)

# Size = bytes of all parameters and buffers. NOTE(review): for quantized
# tensor subclasses, element_size() may reflect the wrapper's nominal dtype
# rather than packed 4-bit storage, so the Int4 figure is only an estimate.
size_fp32_mb = get_model_size_mb(model)
print(f"Original Model Size: {size_fp32_mb:.2f} MB")
size_int4_mb = get_model_size_mb(q_model)
print(f"Quantized Model Size: {size_int4_mb:.2f} MB (Estimated compressed)")

# Latency: the FP32 model is timed first, then the Int4 model, each with its
# own warm-up phase inside benchmark().
lat_fp32, lat_int4 = [benchmark(m, dummy_input) for m in (model, q_model)]

for label, latency in (("FP32", lat_fp32), ("Int4", lat_int4)):
    print(f"{label} Latency: {latency:.2f} ms")

## 4. Visualization
To check that quality is preserved, let's visualize the last block's CLS-token attention maps (averaged over heads) for both models.

In [None]:
# Load a real example image; fall back to a solid red placeholder if the
# download or decoding fails (e.g. when running offline).
from io import BytesIO

from PIL import Image

url = "https://dl.fbaipublicfiles.com/dinov2/images/example.jpg"
try:
    # Imported here so a missing `requests` package also falls back.
    import requests

    response = requests.get(url, timeout=10)
    response.raise_for_status()
    # convert("RGB") so grayscale/RGBA inputs match the 3-channel normalization.
    img_pil = Image.open(BytesIO(response.content)).convert("RGB")
except (ImportError, OSError) as err:
    # requests' RequestException and PIL's UnidentifiedImageError both derive
    # from OSError, so this covers network, HTTP and decoding failures without
    # swallowing unrelated errors such as KeyboardInterrupt.
    print(f"Could not load example image ({err}); using a placeholder.")
    img_pil = Image.new('RGB', (224, 224), color='red')

# Preprocess: fixed 224x224 resize, conversion to a CHW float tensor, then
# normalization with the ImageNet channel mean/std.
import torchvision.transforms as T

_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
transform = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=_imagenet_mean, std=_imagenet_std),
    ]
)
_img_chw = transform(img_pil)
img_tensor = _img_chw.unsqueeze(0).to(device)  # add the batch dimension

# Helper (not a forward hook) that fetches last-block attention from a model.
def get_attention(m, x):
    """Return ``m.get_last_selfattention(x)``.

    Thin convenience wrapper; it assumes ``m`` exposes a
    ``get_last_selfattention`` method and raises ``AttributeError`` otherwise.
    DINOv3 models may alternatively offer ``get_intermediate_layers``.
    """
    attention_fn = m.get_last_selfattention
    return attention_fn(x)

# Compare the last block's CLS-token attention between the FP32 and Int4
# models. This relies on the model exposing get_last_selfattention; if it
# does not, a message is printed instead of the figure.
import math


def _cls_attention_map(att):
    """Return image 0's mean-over-heads CLS->token attention as a square grid.

    ``att`` is indexed as ``[batch, head, query, key]`` (the layout this
    slicing assumes; confirm against the model). Query 0 is the CLS token and
    key 0 (CLS itself) is dropped. The grid side is derived from the token
    count instead of being hard-coded, so both patch-14 (16x16 grid at 224px)
    and patch-16 (14x14) models work. If extra non-patch tokens sit between
    CLS and the patches (e.g. register tokens -- NOTE(review): verify whether
    this model uses them), only the last ``side * side`` keys are kept, where
    ``side`` is the largest integer with ``side * side <= n_tokens``. This
    heuristic assumes fewer than ``2 * side + 1`` such extra tokens.
    """
    cls_attn = att[0, :, 0, 1:].mean(0)
    n_tokens = cls_attn.shape[-1]
    side = math.isqrt(n_tokens)
    patch_attn = cls_attn[n_tokens - side * side:]
    # .float() first: NumPy has no bfloat16, so .numpy() would fail on
    # low-precision outputs that a quantized model may produce.
    return patch_attn.reshape(side, side).float().cpu().numpy()


# Only the attention extraction is guarded, so AttributeErrors raised while
# plotting are not misreported as a missing get_last_selfattention.
try:
    with torch.inference_mode():
        att_fp32 = model.get_last_selfattention(img_tensor)
        att_int4 = q_model.get_last_selfattention(img_tensor)
except AttributeError:
    print("Model does not support get_last_selfattention directly or is wrapped unpredictably.")
else:
    att_map_fp32 = _cls_attention_map(att_fp32)
    att_map_int4 = _cls_attention_map(att_int4)

    img_224 = img_pil.resize((224, 224))

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(img_224)
    ax[0].set_title("Input Image")

    # A fresh array per call, in case visualize_attention modifies its input.
    ax[1].imshow(viz.visualize_attention(np.array(img_224), att_map_fp32))
    ax[1].set_title("FP32 Attention")

    ax[2].imshow(viz.visualize_attention(np.array(img_224), att_map_int4))
    ax[2].set_title("Int4 Attention")
    plt.show()