# Deep ViT Features - Co-segmentation
Given a set of images, segment all the common objects among them.

In [1]:
#@title Installations and mounting
# Colab-style setup (currently disabled): uncomment the lines below to install the
# dependencies, clone the repository, and put the clone on the import path.
# !pip install tqdm
# !pip install faiss-cpu
# !pip install timm
# !pip install opencv-python
# !pip install git+https://github.com/lucasb-eyer/pydensecrf.git
# !git clone https://github.com/ShirAmir/dino-vit-features.git
# import sys
# sys.path.append('dino-vit-features')
# Local setup: switch the working directory to an existing checkout of the repository.
# NOTE(review): this path is specific to one machine -- adjust it to your own checkout.
%cd ~/dev/dino-vit-features
# Enable autoreload so edited modules are re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

/users/r/ramzie/dev/dino-vit-features


## Change Runtime Type
To get a GPU in Google Colab, go to the top menu: Runtime ➔ Change runtime type and select GPU as Hardware accelerator.


In [26]:
#@title Configuration:
#@markdown Choose image paths:
# Paths are relative to the repository root, i.e. the working directory after the
# `%cd` in the setup cell. When running from the directory that *contains* the clone
# instead (the commented-out Colab setup), prefix each path with 'dino-vit-features/'.
# A single assignment is kept so the value is unambiguous and the Colab form shows
# only one field for it.
images_paths = ['images/cat.jpg', 'images/ibex.jpg'] #@param
#@markdown Choose loading size:
load_size = 360 #@param
#@markdown Choose layer of descriptor:
layer = 11 #@param
#@markdown Choose facet of descriptor:
facet = 'key' #@param
#@markdown Choose if to use a binned descriptor:
# NOTE(review): `bin` shadows the Python builtin of the same name; the name is kept
# because the cell that calls find_cosegmentation reads it.
bin=False #@param
#@markdown Choose fg / bg threshold:
thresh=0.065 #@param
#@markdown Choose model type:
model_type='vit_small_patch8_224' #@param
#@markdown Choose stride:
stride=4 #@param
#@markdown Choose elbow coefficient for setting number of clusters
elbow=0.975 #@param
#@markdown Choose percentage of votes to make a cluster salient.
votes_percentage=75 #@param
#@markdown Choose whether to remove outlier images
remove_outliers=False #@param
#@markdown Choose threshold to distinguish inliers from outliers
outliers_thresh=0.7 #@param
#@markdown Choose interval for sampling descriptors for training
sample_interval=100 #@param
#@markdown Use low resolution saliency maps -- reduces RAM usage.
low_res_saliency_maps=True #@param

In [28]:
import matplotlib.pyplot as plt
import torch
from cosegmentation import find_cosegmentation, draw_cosegmentation, draw_cosegmentation_binary_masks

# Everything below runs with autograd disabled.
with torch.no_grad():
    # Run co-segmentation with the settings chosen in the configuration cell.
    # The call returns a pair that is unpacked into the masks and the loaded images.
    seg_masks, pil_images = find_cosegmentation(
        images_paths, elbow, load_size, layer, facet, bin, thresh, model_type,
        stride, votes_percentage, sample_interval, remove_outliers,
        outliers_thresh, low_res_saliency_maps,
    )

    # One axis-free figure per input image; figure and axes handles are collected
    # in parallel lists.
    figs = []
    axes = []
    for pil_image in pil_images:
        fig, ax = plt.subplots()
        ax.axis('off')
        ax.imshow(pil_image)
        figs += [fig]
        axes += [ax]

    # Build the result figures (presumably one set of binary masks and one set of
    # masked images -- see the cosegmentation module for the exact rendering).
    binary_mask_figs = draw_cosegmentation_binary_masks(seg_masks)
    chessboard_bg_figs = draw_cosegmentation(seg_masks, pil_images)

    # Display every figure created above.
    plt.show()

TypeError: unsupported operand type(s) for //: 'tuple' and 'int'

In [12]:
import timm

In [19]:
# List the model names registered in the installed timm that match the wildcard '*vi*'.
timm.list_models('*vi*')

['convit_base',
 'convit_small',
 'convit_tiny',
 'vit_base_r26_s32_224',
 'vit_base_r50_s16_224',
 'vit_base_r50_s16_224_in21k',
 'vit_base_r50_s16_384',
 'vit_base_resnet26d_224',
 'vit_base_resnet50_224_in21k',
 'vit_base_resnet50_384',
 'vit_base_resnet50d_224',
 'vit_large_r50_s32_224',
 'vit_large_r50_s32_224_in21k',
 'vit_large_r50_s32_384',
 'vit_small_r26_s32_224',
 'vit_small_r26_s32_224_in21k',
 'vit_small_r26_s32_384',
 'vit_small_resnet26d_224',
 'vit_small_resnet50d_s16_224',
 'vit_tiny_r_s16_p8_224',
 'vit_tiny_r_s16_p8_224_in21k',
 'vit_tiny_r_s16_p8_384']

In [9]:
# Show the version of the timm module that is currently imported in this kernel.
timm.__version__

'0.4.12'

In [10]:
pip install -U timm

Collecting timm
  Using cached timm-0.6.12-py3-none-any.whl (549 kB)
Collecting huggingface-hub
  Using cached huggingface_hub-0.12.0-py3-none-any.whl (190 kB)
Collecting filelock
  Using cached filelock-3.9.0-py3-none-any.whl (9.7 kB)
Collecting requests
  Using cached requests-2.28.2-py3-none-any.whl (62 kB)
Collecting urllib3<1.27,>=1.21.1
  Using cached urllib3-1.26.14-py2.py3-none-any.whl (140 kB)
Collecting charset-normalizer<4,>=2
  Downloading charset_normalizer-3.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (198 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m198.8/198.8 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: charset-normalizer, urllib3, filelock, requests, huggingface-hub, timm
  Attempting uninstall: timm
    Found existing installation: timm 0.4.12
    Uninstalling timm-0.4.12:
      Successfully uninstalled timm-0.4.12
Successfully installed charset-normalizer-3.0.1 filelock-3.9.0 huggingface-hu

In [23]:
# Instantiate the 'vit_small_patch8_224.dino' model and display its module tree.
# NOTE(review): no `pretrained` argument is passed -- confirm whether pretrained
# weights are intended here.
timm.create_model('vit_small_patch8_224.dino')

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(8, 8), stride=(8, 8))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU()
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
      (l

In [21]:
%pip install git+https://github.com/rwightman/pytorch-image-models.git

Collecting git+https://github.com/rwightman/pytorch-image-models.git
  Cloning https://github.com/rwightman/pytorch-image-models.git to /tmp/pip-req-build-vux6nwej
  Running command git clone --filter=blob:none --quiet https://github.com/rwightman/pytorch-image-models.git /tmp/pip-req-build-vux6nwej
  Resolved https://github.com/rwightman/pytorch-image-models.git to commit 709d5e0d9d2d3f501531506eda96a435737223a3
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: timm
  Building wheel for timm (setup.py) ... [?25ldone
[?25h  Created wheel for timm: filename=timm-0.8.11.dev0-py3-none-any.whl size=1989971 sha256=0351e91465ca08f74148d45d2265012f45ec76940377b04e3a3ae82064ec6ed5
  Stored in directory: /tmp/pip-ephem-wheel-cache-fkv1dkb_/wheels/eb/1e/79/4dfc1bba276172378ab3e51ceed8e1e59ff8fba24e453244bd
Successfully built timm
Installing collected packages: timm
  Attempting uninstall: timm
    Found existing installation: timm 0.6.12
    Uninstalling ti