In [1]:
import PIL
from datasets import load_dataset
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as nnf
from torchvision.transforms import ToPILImage
from torchvision import transforms
from torchmetrics.classification import MulticlassJaccardIndex
import numpy as np
from transformers import AutoModelForImageClassification

In [2]:
ade20k_dino = load_dataset("danjacobellis/ade20k_dino",split='validation').with_format("torch")

Resolving data files:   0%|          | 0/65 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/65 [00:00<?, ?it/s]

In [2]:
ade20k = load_dataset("scene_parse_150",split='validation')

In [3]:
model = AutoModelForImageClassification.from_pretrained('facebook/dinov2-giant')
model = model.cuda()

Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at facebook/dinov2-giant and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14').to("cuda")

In [5]:
class LinearSegmentationHead(nn.Module):
    def __init__(self, in_channels=1536, num_classes=150):
        super(LinearSegmentationHead, self).__init__()
        self.conv_seg = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.conv_seg(x)
        return x

In [6]:
checkpoint = torch.load("dinov2_vitg14_ade20k_linear_head.pth")
new_state_dict = {key.replace("decode_head.", ""): value for key, value in checkpoint['state_dict'].items()}
new_state_dict.pop('bn.weight', None)
new_state_dict.pop('bn.bias', None)
new_state_dict.pop('bn.running_var', None)
new_state_dict.pop('bn.running_mean', None)
new_state_dict.pop('bn.num_batches_tracked', None)
head = LinearSegmentationHead()
head.load_state_dict(new_state_dict)
head = head.to("cuda")

In [14]:
%%time
iou = []
for sample in ade20k:
    img = sample['image'].resize((224,224))
    ground_truth = sample['annotation']
    while ground_truth.width > 1000:
        ground_truth = ground_truth.resize((ground_truth.width//2,ground_truth.height//2),PIL.Image.Resampling.NEAREST)
    x = transforms.ToTensor()(img).unsqueeze(0).to("cuda")
    with torch.no_grad():
        
        y = model.dinov2.forward(x)[0]
        # y = model.forward_features(x)
        
        cls_token = y[:, 0].detach()
        patch_tokens = y[:, 1:].detach()
        # cls_token = y['x_norm_clstoken']
        # patch_tokens = y['x_norm_patchtokens']
        
        patch_tokens = patch_tokens.reshape((1,16,16,1536)).permute((0,3,1,2))
        patch_tokens = nnf.interpolate(patch_tokens,
                                       size=(ground_truth.height,ground_truth.width),
                                       mode='bilinear'
                                      )
        logits = head(patch_tokens)
        predicted = ToPILImage()(logits[0].argmax(dim=0).to(torch.uint8))
        # predicted = predicted.resize((ground_truth.width,ground_truth.height),resample=PIL.Image.Resampling.BILINEAR)
    
        x1 = transforms.PILToTensor()(ground_truth) 
        x2 = transforms.PILToTensor()(predicted)
        x1 = x1-1

        iou.append(
            MulticlassJaccardIndex(num_classes=151,average='micro',ignore_index=255)
            (x1,x2).item()
        )

CPU times: user 3min 38s, sys: 870 ms, total: 3min 39s
Wall time: 2min 44s


In [9]:
# 518x518
np.mean(iou)

0.45564204835228156

In [13]:
# 448x448
np.mean(iou)

0.4499563585400756

In [15]:
# 224x224
np.mean(iou)

0.40693133881329413