In [28]:
import PIL
from datasets import load_dataset
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as nnf
from torchvision.transforms import ToPILImage
from torchvision import transforms
from torchmetrics.classification import MulticlassJaccardIndex
import numpy as np
from transformers import AutoModelForImageClassification

In [2]:
# Validation split of the `danjacobellis/ade20k_dino` dataset, formatted so
# that its columns are returned as torch tensors.
ade20k_dino = load_dataset(
    "danjacobellis/ade20k_dino",
    split="validation",
).with_format("torch")

Resolving data files:   0%|          | 0/65 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/65 [00:00<?, ?it/s]

In [3]:
# Validation split of `scene_parse_150`; the evaluation loop reads the
# 'image' and 'annotation' entries of each sample.
ade20k = load_dataset(
    "scene_parse_150",
    split="validation",
)

In [31]:
# Pretrained DINOv2 (giant) ImageNet classifier, placed on the GPU.
# Only its `dinov2` backbone is used further down to extract patch tokens.
model = AutoModelForImageClassification.from_pretrained(
    'facebook/dinov2-giant-imagenet1k-1-layer'
).cuda()

In [74]:
class LinearSegmentationHead(nn.Module):
    """Per-pixel linear classifier applied to a feature map.

    The input is batch-normalised channel-wise and then projected to
    ``num_classes`` channels by a 1x1 convolution, so every spatial position
    is classified independently.

    Args:
        in_channels: Number of channels of the incoming feature map.
        num_classes: Number of output channels (one logit map per class).
    """

    def __init__(self, in_channels=1536, num_classes=150):
        super().__init__()
        # Attribute names `conv_seg` and `bn` are what the checkpoint keys map
        # onto after the `decode_head.` prefix is stripped, so keep them.
        self.conv_seg = nn.Conv2d(
            in_channels=in_channels,
            out_channels=num_classes,
            kernel_size=1,
        )
        self.bn = nn.BatchNorm2d(num_features=in_channels)

    def forward(self, x):
        """Return class logits of shape (N, num_classes, H, W) for x of shape (N, in_channels, H, W)."""
        normalised = self.bn(x)
        return self.conv_seg(normalised)

In [86]:
# Load the pretrained ADE20K linear-head weights into `LinearSegmentationHead`.
#
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints obtained from a trusted source.
# map_location="cpu" lets the file be read regardless of the device it was
# saved from; the module is moved to the GPU explicitly below.
checkpoint = torch.load("dinov2_vitg14_ade20k_linear_head.pth", map_location="cpu")

# Checkpoint keys carry a "decode_head." prefix that the standalone head does
# not have. Strip it only where it really is a prefix.
_prefix = "decode_head."
new_state_dict = {
    (key[len(_prefix):] if key.startswith(_prefix) else key): value
    for key, value in checkpoint['state_dict'].items()
}

head = LinearSegmentationHead()
head.load_state_dict(new_state_dict)
# eval() is required: in training mode BatchNorm2d would normalise with the
# statistics of the current input and overwrite the loaded running mean/var
# on every forward pass (even under torch.no_grad()).
head = head.to("cuda").eval()

In [None]:
# Evaluate the DINOv2 backbone + linear head on every sample of `ade20k`.
# `iou` ends up holding one micro-averaged IoU value (float) per image.

INPUT_SIZE = 224      # side length the image is resized to before the backbone
PATCH_SIZE = 14       # ViT-g/14 patch size -> 224 // 14 = 16 patches per side
NUM_CLASSES = 150     # matches the number of output channels of the head
IGNORE_INDEX = 255    # label assigned to unlabeled pixels, excluded from the metric

# The backbone expects RGB input normalised with the ImageNet statistics.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
to_label_tensor = transforms.PILToTensor()

# Built once and reused: calling the metric returns the value for the current
# image only (its internal state additionally accumulates over all images).
jaccard = MulticlassJaccardIndex(
    num_classes=NUM_CLASSES, average='micro', ignore_index=IGNORE_INDEX
)

# Inference mode: BatchNorm in the head must use its stored running statistics
# instead of per-image statistics (and must not update them).
model.eval()
head.eval()

iou = []
for sample in ade20k:
    # convert("RGB") guards against grayscale/paletted images, which would
    # otherwise yield a tensor with the wrong number of channels.
    img = sample['image'].convert("RGB").resize((INPUT_SIZE, INPUT_SIZE))
    ground_truth = sample['annotation']
    x = preprocess(img).unsqueeze(0).to("cuda")

    with torch.no_grad():
        # Output [0] is the token sequence; token 0 is the CLS token and the
        # remaining tokens are the image patches in row-major order.
        tokens = model.dinov2(x)[0]
        grid = INPUT_SIZE // PATCH_SIZE
        patch_tokens = tokens[:, 1:].reshape(1, grid, grid, -1).permute(0, 3, 1, 2)

        # Classify at patch resolution first, then upsample the class logits.
        # In eval mode the head is a per-pixel affine map, so this matches
        # upsampling the features first (up to float rounding) while
        # interpolating NUM_CLASSES channels instead of the full feature width.
        logits = head(patch_tokens)
        logits = nnf.interpolate(
            logits,
            size=(ground_truth.height, ground_truth.width),
            mode='bicubic',
            align_corners=True,
        )
        predicted = logits.argmax(dim=1).cpu()  # (1, H, W), class ids 0..NUM_CLASSES-1

    # Annotation value 0 means "unlabeled" and classes start at 1: shift to
    # 0-based class ids and send the unlabeled pixels to IGNORE_INDEX.
    target = to_label_tensor(ground_truth).long() - 1  # (1, H, W)
    target[target < 0] = IGNORE_INDEX

    # torchmetrics order is (preds, target); ignore_index is applied to target.
    iou.append(jaccard(predicted, target).item())