# DINO features for a perceptual loss

Download the pretrained DINO backbones from PyTorch Hub

In [None]:
import torch

# All DINO backbones are published on PyTorch Hub under this repository.
DINO_HUB_REPO = 'facebookresearch/dino:main'


def _load_frozen(name):
    """Load a DINO backbone from PyTorch Hub as a frozen feature extractor.

    nn.Module instances start in training mode; unless the hub entry point
    switches to eval itself, BatchNorm layers (ResNet-50) would normalise with
    per-batch statistics and overwrite their running statistics on every
    forward pass. eval() prevents that, and requires_grad_(False) keeps the
    backbone weights fixed when it is used inside a training loss.
    """
    return torch.hub.load(DINO_HUB_REPO, name).eval().requires_grad_(False)


vits16 = _load_frozen('dino_vits16')
# Other available entry points (swap the name above to try them):
#   'dino_vits8', 'dino_vitb16', 'dino_vitb8',
#   'dino_xcit_small_12_p16', 'dino_xcit_small_12_p8',
#   'dino_xcit_medium_24_p16', 'dino_xcit_medium_24_p8'

resnet50 = _load_frozen('dino_resnet50')

Using cache found in /home/rtxmsi1/.cache/torch/hub/facebookresearch_dino_main


In [None]:
# Show the ResNet-50 architecture: the notebook renders the repr of the cell's last expression.
resnet50

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

Read and preprocess the input image

In [24]:
from torchvision import transforms as pth_transforms
from PIL import Image

# Path to the reference image; adjust for your machine.
IMAGE_PATH = "/home/rtxmsi1/Documents/DINO/src/Vd-Orig.png"
# ImageNet channel mean / std used for normalisation.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# The context manager releases the file handle once the RGB copy exists.
with Image.open(IMAGE_PATH) as pil_img:
    rgb = pil_img.convert('RGB')

transform = pth_transforms.Compose([
    # Resize with an int scales the *shorter* side to 128 px and keeps the
    # aspect ratio, so non-square inputs will not come out as 128x128.
    pth_transforms.Resize(128),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

img = transform(rgb)  # [3, H, W] float tensor
img.shape

torch.Size([3, 128, 128])

### Using the CLS token

Pros: designed as a global summary of the image

In [7]:
# Mirrors the feature extraction of DINO's linear evaluation:
# https://github.com/facebookresearch/dino/blob/main/eval_linear.py#L153
arch = "vit"     # "vit" -> CLS tokens from vits16; anything else -> resnet50 features
avgpool = False  # also append the mean of the last returned block's patch tokens
n = 10           # number of last transformer blocks whose CLS tokens are used

with torch.no_grad():
    if "vit" in arch:
        # List of n tensors, one per block, each [B, tokens, D] with CLS at index 0.
        intermediate_output = vits16.get_intermediate_layers(img.unsqueeze(0), n)
        # Concatenate the CLS token of every returned block -> [B, n * D].
        output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
        if avgpool:
            # Mean over the patch tokens (CLS excluded) of the last block -> [B, D].
            patch_mean = intermediate_output[-1][:, 1:].mean(dim=1)
            # Plain concatenation works for any n; the original unsqueeze/cat/reshape
            # interleaving only lines up shapes when n == 1.
            output = torch.cat((output, patch_mean), dim=-1)
    else:
        output = resnet50(img.unsqueeze(0))

output.shape

torch.Size([1, 3840])

In [28]:
# Keep the ResNet-50 features in their own variable: `output` holds the 3840-d
# ViT CLS features that are compared against the random vector `x` below, and
# overwriting it with a 2048-d tensor would make that MSE fail on shape mismatch.
with torch.no_grad():  # inference only; no autograd graph needed
    resnet_output = resnet50(img.unsqueeze(0))
resnet_output.shape

torch.Size([1, 2048])

In [9]:
import torch

# Uniform noise on [-10, 10): rand() samples [0, 1), then scale by 20 and shift by -10.
x = torch.rand(1, 3840).mul(20).sub(10)
print(x.shape)  # expected: torch.Size([1, 3840])
print(x.min(), x.max())  # extremes should sit close to -10 and 10

torch.Size([1, 3840])
tensor(-9.9963) tensor(9.9970)


In [12]:
from torch.nn import functional as F

# Summed squared error between the features in `output` (from an earlier cell)
# and the random vector `x`.
mse = F.mse_loss(input=output, target=x, reduction="sum")
mse

tensor(191322.4375)

### Mean-pooled tokens

What: average over all tokens of a layer (the CLS token and the patch embeddings)

In [20]:
# Mean-pool every token (CLS included) of the second returned block -> [B, D].
pooled = intermediate_output[1].mean(dim=1)
print(pooled.shape)  # printed explicitly: a bare expression mid-cell is never displayed

# Sanity check: a feature vector compared with itself gives exactly zero loss.
mse = F.mse_loss(pooled, pooled, reduction="sum")
mse

tensor(0.)

### Full token map (spatial loss)
Pros: preserves spatial correspondence, analogous to the VGG feature-map loss

In [21]:
# Full token map of the second returned block, no pooling: [B, tokens, D].
tokens = intermediate_output[1]
print(tokens.shape)  # printed explicitly: a bare expression mid-cell is never displayed

# Sanity check: the token map compared with itself gives exactly zero loss.
mse = F.mse_loss(tokens, tokens, reduction="sum")
mse

tensor(0.)

### Perceptual-loss function

In [None]:
import torch
import torch.nn.functional as F

def dino_perceptual_loss(
    x_real,
    x_recon,
    dino_model,
    layer_ids=None,
    mode='cls',         # 'cls', 'mean', or 'tokens'
    reduction='mean'    # 'mean', 'sum', or 'none'
):
    """
    Compute a perceptual loss between x_real and x_recon using DINO ViT features.

    Features of ``x_real`` are the fixed target and are extracted without
    gradient tracking. Features of ``x_recon`` are extracted with autograd
    enabled, so the loss can be back-propagated into the reconstruction.

    Args:
        x_real (Tensor): Original image batch [B, 3, H, W] (the target).
        x_recon (Tensor): Reconstructed image batch [B, 3, H, W].
        dino_model (nn.Module): DINO ViT exposing ``blocks`` and
            ``get_intermediate_layers``. Freeze it (``requires_grad_(False)``)
            if its weights must not receive gradients.
        layer_ids (Sequence[int] | None): Indices into the list returned by
            ``get_intermediate_layers(x, n=len(blocks) + 1)``. In DINO's ViT that
            list presumably holds one output per block, so 11 is the last block
            of a 12-block model (verify for other backbones). Defaults to [11].
        mode (str): 'cls' compares the CLS token (index 0), 'mean' the mean
            over all tokens (CLS included), 'tokens' the full token map.
        reduction (str): Forwarded to ``F.mse_loss``: 'mean' | 'sum' | 'none'.

    Returns:
        Tensor: Loss averaged over ``layer_ids``. A scalar for 'mean'/'sum'.
        For 'none', the element-wise squared error with the shape of the
        compared features ([B, D] for 'cls'/'mean', [B, T, D] for 'tokens').

    Raises:
        ValueError: If ``mode`` is unknown or ``layer_ids`` is empty.
    """
    if layer_ids is None:
        layer_ids = [11]  # last block of a 12-block ViT such as ViT-S/16
    if not layer_ids:
        raise ValueError("layer_ids must contain at least one layer index")
    # Validate before paying for two forward passes.
    if mode not in ('cls', 'mean', 'tokens'):
        raise ValueError(f"Unknown mode: {mode}")

    # Request more layers than exist so every block's output is returned.
    n_layers = len(dino_model.blocks) + 1

    # Target features: constant w.r.t. the optimisation, no graph needed.
    with torch.no_grad():
        feats_real = dino_model.get_intermediate_layers(x_real, n=n_layers)
    # Reconstruction features: keep the graph so gradients reach x_recon.
    feats_recon = dino_model.get_intermediate_layers(x_recon, n=n_layers)

    loss = 0.0
    for layer in layer_ids:
        f_real = feats_real[layer]  # [B, T, D]
        f_recon = feats_recon[layer]

        if mode == 'cls':
            v_real, v_recon = f_real[:, 0], f_recon[:, 0]  # CLS token
        elif mode == 'mean':
            v_real, v_recon = f_real.mean(dim=1), f_recon.mean(dim=1)
        else:  # 'tokens'
            v_real, v_recon = f_real, f_recon

        loss = loss + F.mse_loss(v_recon, v_real, reduction=reduction)

    return loss / len(layer_ids)
