In [1]:
import PIL
from datasets import load_dataset
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as nnf
from torchvision.transforms import ToPILImage
from torchvision import transforms
from torchmetrics.classification import MulticlassJaccardIndex
import numpy as np
from transformers import AutoModel

In [2]:
ade20k = load_dataset("scene_parse_150",split='validation')

In [3]:
model = AutoModel.from_pretrained('facebook/dinov2-giant')
model = model.cuda()

In [4]:
class LinearSegmentationHead(nn.Module):
    def __init__(self, in_channels=1536, num_classes=150):
        super(LinearSegmentationHead, self).__init__()
        self.conv_seg = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.conv_seg(x)
        return x

In [5]:
checkpoint = torch.load("dinov2_vitg14_ade20k_linear_head.pth")
new_state_dict = {key.replace("decode_head.", ""): value for key, value in checkpoint['state_dict'].items()}
new_state_dict.pop('bn.weight', None)
new_state_dict.pop('bn.bias', None)
new_state_dict.pop('bn.running_var', None)
new_state_dict.pop('bn.running_mean', None)
new_state_dict.pop('bn.num_batches_tracked', None)
head = LinearSegmentationHead()
head.load_state_dict(new_state_dict)
head = head.to("cuda")

In [12]:
%%time
iou = []
for i_sample,sample in enumerate(ade20k):
    img = sample['image'].resize((448,252))
    ground_truth = sample['annotation']
    while ground_truth.width > 1000:
        ground_truth = ground_truth.resize((ground_truth.width//2,ground_truth.height//2),PIL.Image.Resampling.NEAREST)
    x = transforms.ToTensor()(img).unsqueeze(0).to("cuda")
    with torch.no_grad():
        
        y = model.forward(x)[0]
        cls_token = y[:, 0]
        patch_tokens = y[:, 1:]
        patch_tokens = patch_tokens.reshape((1,18,32,1536)).permute((0,3,1,2))
        patch_tokens = nnf.interpolate(patch_tokens,
                                       size=(ground_truth.height,ground_truth.width),
                                       mode='bilinear')
        logits = head(patch_tokens)
        predicted = ToPILImage()(logits[0].argmax(dim=0).to(torch.uint8))
        
        x1 = transforms.PILToTensor()(ground_truth) 
        x2 = transforms.PILToTensor()(predicted)
        x1 = x1-1

        iou.append(
            MulticlassJaccardIndex(num_classes=151,average='micro',ignore_index=255)
            (x1,x2).item()
        )

CPU times: user 5min 39s, sys: 814 ms, total: 5min 40s
Wall time: 4min 43s


In [13]:
# 448x252
np.mean(iou)

0.4219338969966702