# Run on Google Colab using a GPU

## Clone the repo

In [1]:
!rm -rf sample_data/
# -p: don't fail if the directory already exists, so the cell can be re-run.
!mkdir -p out

In [2]:
# -p: don't fail if the directories already exist, so the cell can be re-run.
!mkdir -p input output

In [3]:
# Clone to an absolute path and skip when already present: a later cell does
# `%cd` into the repo, so a relative clone would nest a second copy on re-run.
![ -d /content/dino ] || git clone https://github.com/facebookresearch/dino.git /content/dino

Cloning into 'dino'...
remote: Enumerating objects: 175, done.[K
remote: Counting objects: 100% (69/69), done.[K
remote: Compressing objects: 100% (15/15), done.[K
remote: Total 175 (delta 57), reused 55 (delta 54), pack-reused 106[K
Receiving objects: 100% (175/175), 24.44 MiB | 16.19 MiB/s, done.
Resolving deltas: 100% (110/110), done.


Download a model. Here I used the pretrained DINO ViT-S/8 weights ("deit small", patch size 8).

In [4]:
# Absolute target path so the download lands in the repo regardless of the current directory.
!wget https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth -O /content/dino/dino_deitsmall8_pretrain.pth

--2023-09-21 19:54:38--  https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 13.227.219.59, 13.227.219.70, 13.227.219.10, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|13.227.219.59|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 86728949 (83M) [application/zip]
Saving to: ‘dino/dino_deitsmall8_pretrain.pth’


2023-09-21 19:54:38 (195 MB/s) - ‘dino/dino_deitsmall8_pretrain.pth’ saved [86728949/86728949]



## Look for a video to use and download it

For example, I'm using this one. Save it as `/content/video.mp4`, which is the path the ffmpeg commands below read from:
https://www.pexels.com/fr-fr/video/chien-course-exterieur-journee-ensoleillee-4166347/


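Then extract the frames from the video, for example with ffmpeg.

The video is 60 fps and about 6 seconds long, so you'll get about 360 JPEG images.

`%03d` names the frames 001, 002, … 999 (zero-padded to three digits). A video with more than 999 frames produces names like `img-1000.jpg`, which no longer sort in frame order.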

In [5]:
# Split the uploaded video into numbered JPEG frames: img-001.jpg, img-002.jpg, ...
# %03d zero-pads to three digits; past 999 frames (img-1000.jpg, ...) the names
# no longer sort lexicographically in frame order.
!ffmpeg -i /content/video.mp4 /content/input/img-%03d.jpg

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab

In [6]:
# Absolute path: re-running this cell must not try to enter dino/dino.
%cd /content/dino

/content/dino


## Code

Requirements:


* Opencv
* scikit-image
* matplotlib
* pytorch
* numpy
* Pillow



In [7]:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import gc
import sys
import argparse
import cv2
import random
import colorsys
import requests
from io import BytesIO

import skimage.io
from skimage.measure import find_contours
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms as pth_transforms
import numpy as np
from PIL import Image

import utils
import vision_transformer as vits

You may need to resize each frame; otherwise you'll get an out-of-memory (OOM) error on the GPU.

The size is set by the `pth_transforms.Resize(512),` step in `predict_video`.


Also, the colors of the video in the blog post are obtained with `cmap="inferno"`.

In [18]:
def predict_video(args, model=None):
    """Save a head-averaged [CLS] self-attention heatmap for every frame.

    Every regular file in ``args.image_path`` is processed in sorted filename
    order. Each frame is:
      1. loaded as RGB, resized so its shorter side is 512 px, and normalized;
      2. cropped to a multiple of ``args.patch_size`` and run through the model;
      3. reduced to the last-layer attention of token 0 to every patch token,
         averaged over heads and upsampled back to pixels (nearest neighbour);
      4. written to ``args.output_dir/attn-<frame name>`` with the "inferno"
         colormap.

    Args:
        args: Namespace providing ``image_path``, ``output_dir`` (created if
            missing) and ``patch_size``. ``args.threshold`` is not used here.
        model: Network exposing ``get_last_selfattention``. Defaults to the
            notebook-level ``model`` so existing ``predict_video(args)`` calls
            keep working.
    """
    net = model if model is not None else globals()["model"]
    # Send inputs to whatever device the model's weights live on.
    device = next(net.parameters()).device
    patch = args.patch_size

    # Loop-invariant setup, done once instead of once per frame.
    transform = pth_transforms.Compose([
        pth_transforms.ToTensor(),
        # Shorter side -> 512 px; lower this if the GPU runs out of memory.
        pth_transforms.Resize(512),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    os.makedirs(args.output_dir, exist_ok=True)

    # Only regular files: sub-directories (e.g. .ipynb_checkpoints) can't be opened as images.
    frames = sorted(
        name for name in os.listdir(args.image_path)
        if os.path.isfile(os.path.join(args.image_path, name))
    )
    for frame in frames:
        with open(os.path.join(args.image_path, frame), 'rb') as f:
            img = Image.open(f).convert('RGB')
        img = transform(img)

        # Crop height (dim 1) and width (dim 2) down to multiples of the patch size.
        height = img.shape[1] - img.shape[1] % patch
        width = img.shape[2] - img.shape[2] % patch
        img = img[:, :height, :width].unsqueeze(0)
        h_featmap, w_featmap = height // patch, width // patch

        with torch.no_grad():
            attentions = net.get_last_selfattention(img.to(device))

        nh = attentions.shape[1]  # number of heads
        # Query token 0 attending to every other token (columns 1:), one map per head.
        attentions = attentions[0, :, 0, 1:].reshape(nh, h_featmap, w_featmap)
        attentions = nn.functional.interpolate(
            attentions.unsqueeze(0), scale_factor=patch, mode="nearest"
        )[0].cpu().numpy()

        fname = os.path.join(args.output_dir, "attn-" + frame)
        plt.imsave(
            fname=fname,
            arr=attentions.mean(axis=0),  # average over heads
            cmap="inferno",
            format="jpg"
        )
        print(f"{fname} saved.")

In [19]:
#@title Args

pretrained_weights_path = "dino_deitsmall8_pretrain.pth" #@param {type:"string"}
arch = 'vit_base' #@param ["deit_small", "deit_tiny", "vit_base"]
input_path = "../input/" #@param {type:"string"}
output_path = "../output/" #@param {type:"string"}
threshold = 0.6 #@param {type:"number"}


parser = argparse.ArgumentParser('Visualize Self-Attention maps')
parser.add_argument('--arch', default='deit_small', type=str,
    choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).')
parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
parser.add_argument('--pretrained_weights', default='', type=str,
    help="Path to pretrained weights to load.")
parser.add_argument("--checkpoint_key", default="teacher", type=str,
    help='Key to use in the checkpoint (example: "teacher")')
parser.add_argument("--image_path", default=None, type=str, help="Path of the image to load.")
parser.add_argument('--output_dir', default='.', help='Path where to save visualizations.')
parser.add_argument("--threshold", type=float, default=0.6, help="""We visualize masks
    obtained by thresholding the self-attention maps to keep xx% of the mass.""")

args = parser.parse_args(args=[])

args.arch = arch
args.pretrained_weights = pretrained_weights_path
args.image_path = "../input/"
args.output_dir = "../out/"
args.threshold = threshold

In [20]:
model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
for p in model.parameters():
    p.requires_grad = False
model.eval()
model.cuda()
if os.path.isfile(args.pretrained_weights):
    state_dict = torch.load(args.pretrained_weights, map_location="cpu")
    #if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
    #    print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
    #    state_dict = state_dict[args.checkpoint_key]
    #state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    #msg = model.load_state_dict(state_dict, strict=False)
    #print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
else:
    print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
    url = None
    if args.arch == "deit_small" and args.patch_size == 16:
        url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
    elif args.arch == "deit_small" and args.patch_size == 8:
        url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"  # model used for visualizations in our paper
    elif args.arch == "vit_base" and args.patch_size == 16:
        url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
    elif args.arch == "vit_base" and args.patch_size == 8:
        url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
    if url is not None:
        print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
        state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
        model.load_state_dict(state_dict, strict=True)
    else:
        print("There is no reference weights available for this model => We use random weights.")


In [21]:
# Collect unreachable Python objects first so any CUDA tensors they held are
# returned to PyTorch's caching allocator, then release those cached blocks.
gc.collect()
torch.cuda.empty_cache()

59

## Run inference


If you run out of GPU memory, lower the value passed to `pth_transforms.Resize` in `predict_video`.

In [22]:
# Writes one heatmap per file in args.image_path to args.output_dir as attn-<frame name>.
predict_video(args)

../out/attn-img-001.jpg saved.
../out/attn-img-002.jpg saved.
../out/attn-img-003.jpg saved.
../out/attn-img-004.jpg saved.
../out/attn-img-005.jpg saved.
../out/attn-img-006.jpg saved.
../out/attn-img-007.jpg saved.
../out/attn-img-008.jpg saved.
../out/attn-img-009.jpg saved.
../out/attn-img-010.jpg saved.
../out/attn-img-011.jpg saved.
../out/attn-img-012.jpg saved.
../out/attn-img-013.jpg saved.
../out/attn-img-014.jpg saved.
../out/attn-img-015.jpg saved.
../out/attn-img-016.jpg saved.
../out/attn-img-017.jpg saved.
../out/attn-img-018.jpg saved.
../out/attn-img-019.jpg saved.
../out/attn-img-020.jpg saved.
../out/attn-img-021.jpg saved.
../out/attn-img-022.jpg saved.
../out/attn-img-023.jpg saved.
../out/attn-img-024.jpg saved.
../out/attn-img-025.jpg saved.
../out/attn-img-026.jpg saved.
../out/attn-img-027.jpg saved.
../out/attn-img-028.jpg saved.
../out/attn-img-029.jpg saved.
../out/attn-img-030.jpg saved.
../out/attn-img-031.jpg saved.
../out/attn-img-032.jpg saved.
../out/a

## Output images to video

The input video is 60 fps; change `-framerate 60` to match your video's frame rate.

In [23]:
# Reassemble the heatmaps into a video. predict_video saves each frame as
# <args.output_dir>/attn-<input frame name>, i.e. attn-img-001.jpg, ...; IPython
# expands the {...} expression so the pattern always follows args.output_dir.
# -y overwrites a previous output.mp4; yuv420p keeps the MP4 playable everywhere.
!ffmpeg -y -framerate 60 -i {os.path.join(args.output_dir, 'attn-img-%03d.jpg')} -pix_fmt yuv420p ../output.mp4

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab

If you want both input and output videos side by side:

In [24]:
# Original video on the left, attention video on the right. scale2ref first
# resizes the attention stream to the source resolution: predict_video resizes
# frames to 512 px and crops them to the patch size, so the sizes otherwise differ.
!ffmpeg -y -i ../video.mp4 -i ../output.mp4 -filter_complex '[1:v][0:v]scale2ref[attn][src];[src]pad=iw*2:ih[int];[int][attn]overlay=W/2:0[vid]' -map '[vid]' -c:v libx264 -crf 23 -preset veryfast final.mp4

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enab