# In [None]:
import einops
import matplotlib.pyplot as plt
import requests
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as Ftv
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

# Download the sample image. Fail early on an HTTP error instead of handing
# PIL an error response to decode.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
response = requests.get(url, stream=True, timeout=10)
response.raise_for_status()
image = Image.open(response.raw)

processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
model = AutoModel.from_pretrained("facebook/dinov2-base")

# Preprocess the image, then resize the resulting pixel tensor to 448x448
# (bicubic) before it is fed to the model.
inputs = processor(images=image, return_tensors="pt")
inputs.data["pixel_values"] = F.interpolate(
    inputs.data["pixel_values"], size=(448, 448), mode="bicubic"
)

# Keep an independent copy of the exact tensor that goes into the model.
inputs_processed = inputs.data["pixel_values"].clone()

# Inference only: run the forward pass without autograd bookkeeping so the
# hidden states do not drag a computation graph into the steps that follow.
with torch.no_grad():
    outputs = model(**inputs)
last_hidden_states = outputs[0]

# Tracing needs return_dict disabled, so switch it off on the model config
# before handing the model to torch.jit.trace.
model.config.return_dict = False

trace_input = inputs.pixel_values
with torch.no_grad():
    traced_model = torch.jit.trace(model, [trace_input])
    traced_outputs = traced_model(trace_input)

# Largest element-wise deviation between the eager and the traced hidden states.
max_abs_diff = torch.max(torch.abs(last_hidden_states - traced_outputs[0]))
print(max_abs_diff)

# Show the image and its features: project the per-patch features onto three
# PCA components and render them as an RGB image.

print(last_hidden_states.shape)

# Skip the first token (treated here as the CLS token) and detach, so the
# visualization never carries an autograd graph.
wo_cls = last_hidden_states[:, 1:].detach()

# Derive the patch grid from the tensor that was fed to the model instead of
# hard-coding it, so changing the input resolution keeps the reshape valid.
patch_size = model.config.patch_size
h = inputs_processed.shape[-2] // patch_size
w = inputs_processed.shape[-1] // patch_size
wo_cls = einops.rearrange(wo_cls, "b p c -> (b p) c")

# PCA with torch.pca_lowrank; q=3 already limits V to three columns.
U, S, V = torch.pca_lowrank(wo_cls, q=3)
wo_cls = wo_cls @ V

# Min-max scale every component to [0, 1]. The range is clamped away from zero
# so a constant component yields 0 instead of NaN.
mins = wo_cls.amin(dim=0)
maxs = wo_cls.amax(dim=0)
value_range = (maxs - mins).clamp_min(torch.finfo(wo_cls.dtype).eps)
wo_cls = (wo_cls - mins) / value_range

wo_cls = einops.rearrange(wo_cls, "(b h w) c -> b c h w", h=h, w=w)

plt.imshow(Ftv.to_pil_image(wo_cls[0]))
plt.axis("off")
plt.show()

# Show the image exactly as the model received it (the resized tensor kept in
# inputs_processed). plt.imshow() needs the data to draw; calling it with no
# argument raises TypeError.
# NOTE(review): assumes the processor normalized the pixels with its own
# image_mean / image_std — confirm if the preprocessing config is changed.
input_mean = torch.tensor(processor.image_mean).view(-1, 1, 1)
input_std = torch.tensor(processor.image_std).view(-1, 1, 1)
# Undo the normalization, then clamp: bicubic resizing can overshoot the
# [0, 1] range and out-of-range values would wrap when converted to uint8.
input_image = (inputs_processed[0] * input_std + input_mean).clamp(0.0, 1.0)

plt.imshow(Ftv.to_pil_image(input_image))
plt.axis("off")
plt.show()