In [1]:
from json import load
import torch
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor,Compose, Resize, Normalize
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from torch import nn
import torch.nn.functional as F
from PIL import Image
import os 
import matplotlib.pyplot as plt
import numpy as np
from pprint import pprint as pp



### Model Analysis
Modelos para extração de carcateristicas com DINO- Versões 1 e 2 e suas variações


In [2]:

#Dino V1
vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8') # funcionando
vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')

# DINOv2
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

# DINOv2 with registers
dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')

# xcit_small_12_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p16')
# xcit_small_12_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p8')
# xcit_medium_24_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p16')
# xcit_medium_24_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8')
# resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')

Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dino_main
Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main
Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main


# DINO Perceptual Losses

In [3]:
import torch
import torch.nn as nn

class PerceptualLoss(nn.Module):
    def __init__(self, model, layers=None, normalize_inputs=False):
        """
        Perceptual loss using DINO or DINOv2 models.

        Args:
            model (torch.nn.Module): The DINO model to extract features.
            layers (list of str): Names of layers to use for loss computation.
                                  Default is None, which uses all layers.
            normalize_inputs (bool): Whether to normalize inputs to [0, 1].
        """
        super(PerceptualLoss, self).__init__()
        self.model = model.eval()  # Set model to evaluation mode
        self.layers = layers
        self.normalize_inputs = normalize_inputs

        # Disable gradient computation for the model
        for param in self.model.parameters():
            param.requires_grad = False

    def _ensure_tensor(self, feat):
            if isinstance(feat, tuple):
                feat = feat[0]
            return feat
    
    def extract_features(self, x):
        """
        Extract features from the model.

        Args:
            x (torch.Tensor): Input image tensor (B, C, H, W).

        Returns:
            list of torch.Tensor: Extracted features.
        """
        features = []
        hooks = []

        # Hook to extract features from specified layers
        def hook_fn(module, input, output):
            features.append(output)

        # Register hooks on specified layers or use all layers
        if self.layers is None:
            for name, module in self.model.named_modules():
                hooks.append(module.register_forward_hook(hook_fn))
        else:
            for name, module in self.model.named_modules():
                if name in self.layers:
                    hooks.append(module.register_forward_hook(hook_fn))

        # Forward pass to get features
        self.model(x)

        # Remove hooks
        for hook in hooks:
            hook.remove()

        return features


    def forward(self, input, target):
        
        """
        Compute perceptual loss between input and target images.

        Args:
            input (torch.Tensor): Input image tensor (B, C, H, W).
            target (torch.Tensor): Target image tensor (B, C, H, W).

        Returns:
            torch.Tensor: Perceptual loss value.
        # """
        if self.normalize_inputs:
            input = (input - input.min()) / (input.max() - input.min())
            target = (target - target.min()) / (target.max() - target.min())

        # Extract features
        input_features = self.extract_features(input)
        target_features = self.extract_features(target)

        # Compute loss
        loss = 0
        for inp_feat, tgt_feat in zip(input_features, target_features):
            
            inp_feat = self._ensure_tensor(inp_feat)
            tgt_feat = self._ensure_tensor(tgt_feat)

            loss += nn.functional.smooth_l1_loss(inp_feat, tgt_feat).to('cpu')
        return loss


In [91]:
import time
# Função para medir o tempo de execução
def measure_execution_time(model, input, target):
    start_time = time.time()
    loss = model(input, target)
    end_time = time.time()
    print(f"Tempo de execução: {end_time - start_time} segundos")
    return loss

### Funcionamento

Esse codigo funciona para extrair features e calacular uma loss function baseada nessa comparação de features com uma função de perda. Aplicados assim nas arquiteturas pre-treinadas de versoes do DINO. O objetivo é usar versões diferentes do DINO como funções perceptuais. 
* Funcionando
    * DINO V1 loss functions completmante funcionais. (DONE)
    * DINO V2
    * DINO V2 with registers

    
**Observações**


Esta sendo usada a smooth L1 por sua melhor aplicabilidade se comparada a MSE e a L! individualmente.

    




In [None]:
#Dino V1
# vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8') # funcionando
# vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
# vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
# vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')

# # DINOv2
# dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
# dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
# dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
# dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

# # DINOv2 with registers
# dinov2_vits14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
# dinov2_vitb14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
# dinov2_vitl14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
# dinov2_vitg14_reg = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')




### DINO V1

In [92]:
# Exemplos de uso
# Escolha o modelo desejado
model = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')

# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)

Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dino_main


Input image: torch.Size([1, 3, 224, 224]) tensor(7.4506e-06) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(2.3842e-06) tensor(1.0000)
Perda perceptual: 37.65395736694336
Tempo de execução: 0.41363000869750977 segundos


tensor(37.6540)

In [93]:

# Exemplos de uso
# Escolha o modelo desejado
model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')

# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)

Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dino_main


Input image: torch.Size([1, 3, 224, 224]) tensor(3.0994e-06) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(4.1723e-06) tensor(1.0000)
Perda perceptual: 20.65361213684082
Tempo de execução: 1.1973350048065186 segundos


tensor(20.6536)

In [95]:

# Exemplos de uso
# Escolha o modelo desejado
model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')

# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)

Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dino_main


Input image: torch.Size([1, 3, 224, 224]) tensor(1.1921e-05) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(1.0192e-05) tensor(1.0000)
Perda perceptual: 37.05283737182617
Tempo de execução: 0.10273957252502441 segundos


tensor(37.0528)

In [None]:
# Exemplos de uso
# Escolha o modelo desejado
model = torch.hub.load('facebookresearch/dino:main', 'dinov2_vits14')

# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)

Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dino_main


Input image: torch.Size([1, 3, 224, 224]) tensor(2.1458e-06) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(3.9935e-06) tensor(1.0000)
Perda perceptual: 24.463768005371094
Tempo de execução: 0.25011420249938965 segundos


tensor(24.4638)

### DINO V2

In [98]:
# Exemplos de uso
# Escolha o modelo desejado
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')


# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)

Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main


Input image: torch.Size([1, 3, 224, 224]) tensor(5.2452e-06) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(2.3246e-06) tensor(1.0000)
Perda perceptual: 18.904483795166016
Tempo de execução: 0.11289119720458984 segundos


tensor(18.9045)

In [99]:
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')

# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)

Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main


Input image: torch.Size([1, 3, 224, 224]) tensor(1.0490e-05) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(2.6226e-06) tensor(1.0000)
Perda perceptual: 13.746903419494629
Tempo de execução: 0.34348106384277344 segundos


tensor(13.7469)

In [100]:

model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')

# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)

Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main


Input image: torch.Size([1, 3, 224, 224]) tensor(1.8418e-05) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(5.4240e-06) tensor(1.0000)
Perda perceptual: 23.665300369262695
Tempo de execução: 1.2107529640197754 segundos


tensor(23.6653)

In [101]:

model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')


# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)

Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main


Input image: torch.Size([1, 3, 224, 224]) tensor(5.5432e-06) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(2.1160e-05) tensor(1.0000)
Perda perceptual: 33.627864837646484
Tempo de execução: 3.510894536972046 segundos


tensor(33.6279)

### DINO V2 with registers

In [103]:
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg')
# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)


Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main


Input image: torch.Size([1, 3, 224, 224]) tensor(6.9141e-06) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(3.8147e-06) tensor(1.0000)
Perda perceptual: 14.188370704650879
Tempo de execução: 0.1164255142211914 segundos


tensor(14.1884)

In [104]:
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)


Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main


Input image: torch.Size([1, 3, 224, 224]) tensor(7.5698e-06) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(1.7881e-07) tensor(1.0000)
Perda perceptual: 15.03640079498291
Tempo de execução: 0.35260581970214844 segundos


tensor(15.0364)

In [105]:

model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg')
# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)

Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main


Input image: torch.Size([1, 3, 224, 224]) tensor(1.5199e-05) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(1.0133e-06) tensor(1.0000)
Perda perceptual: 33.44691467285156
Tempo de execução: 1.1680426597595215 segundos


tensor(33.4469)

In [106]:

model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_reg')
# Instancie a função de perda perceptual
perceptual_loss = PerceptualLoss(model=model,normalize_inputs=False)

# Defina imagens de exemplo
input_image = torch.rand(1, 3, 224, 224)  # Batch de 1 imagem (C, H, W)
target_image = torch.rand(1, 3, 224, 224)

print("Input image:", input_image.shape, input_image.min(), input_image.max())
print("Target image:", target_image.shape, target_image.min(), target_image.max())
# Compute a perda perceptual
loss_value = perceptual_loss(input_image, target_image)
print("Perda perceptual:", loss_value.item())

measure_execution_time(perceptual_loss, input_image, target_image)

Using cache found in /home/pdi_4/.cache/torch/hub/facebookresearch_dinov2_main


Input image: torch.Size([1, 3, 224, 224]) tensor(3.0994e-06) tensor(1.0000)
Target image: torch.Size([1, 3, 224, 224]) tensor(1.1325e-06) tensor(1.0000)
Perda perceptual: 38.798404693603516
Tempo de execução: 3.630220413208008 segundos


tensor(38.7984)