<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/DINOv2_Encoder_Downstream_Tasks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

DINOv2 for Classification:

In [16]:
from transformers import AutoImageProcessor, Dinov2Model
import torch
import torch.nn as nn
from PIL import Image
import requests

class DINO_Classification(nn.Module):
    """DINOv2 backbone followed by a single linear classification head.

    The head is fed the first (CLS) token embedding concatenated with the
    mean of the remaining patch token embeddings, so its input width is
    twice the backbone's hidden size.

    Args:
        number_classes: Number of output classes (width of the logits).
        model_name: Checkpoint identifier forwarded to
            ``Dinov2Model.from_pretrained``. The head width is read from the
            loaded config, so it follows whichever checkpoint is chosen.
    """

    def __init__(self, number_classes=1000, model_name="facebook/dinov2-base"):
        super().__init__()
        self.dinov2 = Dinov2Model.from_pretrained(model_name)
        # CLS embedding + mean patch embedding are concatenated -> 2 * hidden_size.
        self.classifier = nn.Linear(self.dinov2.config.hidden_size * 2, number_classes)

    def forward(self, input):
        """Compute class logits for a batch of images.

        Args:
            input: Tensor passed straight to the backbone (the demo below
                supplies the processor's ``pixel_values``).

        Returns:
            Tensor of shape ``(batch_size, number_classes)``.
        """
        outputs = self.dinov2(input)
        sequence_output = outputs[0]  # batch_size, sequence_length, hidden_size
        cls_token = sequence_output[:, 0]
        patch_tokens = sequence_output[:, 1:]
        # Global image descriptor: CLS token plus average-pooled patch tokens.
        pooled = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
        return self.classifier(pooled)

# Download a sample COCO validation image to run through the classifier.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# A timeout stops the cell from hanging forever on a stalled connection, and
# raise_for_status() reports HTTP errors directly instead of letting PIL fail
# on an error page.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
model = DINO_Classification(number_classes=20)
# Inference only: make sure every submodule is in evaluation mode.
model.eval()

# The processor returns a dict-like object; 'pixel_values' is the model input.
inputs = image_processor(image, return_tensors="pt")

# No gradients are needed for a forward-only pass.
with torch.no_grad():
    logits = model(inputs['pixel_values'])

print('logits:', logits.shape)

logits: torch.Size([1, 20])


DINOv2 for Regression:

In [20]:
from transformers import AutoImageProcessor, Dinov2Model
import torch
import torch.nn as nn
from PIL import Image
import requests

class DINO_Regression(nn.Module):
    """DINOv2 backbone followed by a one-unit regression head.

    The head is a linear layer followed by a sigmoid, so each prediction is
    a single value in the open interval (0, 1). Its input is the first (CLS)
    token embedding concatenated with the mean of the remaining patch token
    embeddings, i.e. twice the backbone's hidden size.

    Args:
        model_name: Checkpoint identifier forwarded to
            ``Dinov2Model.from_pretrained``. The head width is read from the
            loaded config, so it follows whichever checkpoint is chosen.
    """

    def __init__(self, model_name="facebook/dinov2-base"):
        super().__init__()
        self.dinov2 = Dinov2Model.from_pretrained(model_name)
        self.regressor = nn.Sequential(
            # CLS embedding + mean patch embedding are concatenated -> 2 * hidden_size.
            nn.Linear(self.dinov2.config.hidden_size * 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Compute one bounded scalar prediction per image.

        Args:
            input: Tensor passed straight to the backbone (the demo below
                supplies the processor's ``pixel_values``).

        Returns:
            Tensor of shape ``(batch_size, 1)`` holding sigmoid outputs.
        """
        outputs = self.dinov2(input)
        sequence_output = outputs[0]  # batch_size, sequence_length, hidden_size
        cls_token = sequence_output[:, 0]
        patch_tokens = sequence_output[:, 1:]
        # Global image descriptor: CLS token plus average-pooled patch tokens.
        pooled = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=1)
        # Already passed through the sigmoid, so this is a prediction, not a logit.
        prediction = self.regressor(pooled)

        return prediction

# Download a sample COCO validation image to run through the regressor.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# A timeout stops the cell from hanging forever on a stalled connection, and
# raise_for_status() reports HTTP errors directly instead of letting PIL fail
# on an error page.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
model = DINO_Regression()
# Inference only: make sure every submodule is in evaluation mode.
model.eval()

# The processor returns a dict-like object; 'pixel_values' is the model input.
inputs = image_processor(image, return_tensors="pt")

# No gradients are needed for a forward-only pass.
with torch.no_grad():
    logits = model(inputs['pixel_values'])

# The regression head is freshly initialised, so the printed value varies
# between runs.
print('logits:', logits)

logits: tensor([[0.7806]])
