<a href="https://colab.research.google.com/github/graumannm/Berlin_Bike_CV/blob/main/s01_extract_embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### This notebook loads a model and runs a forward pass over all images in our bike lane dataset to extract the embeddings from a pre-defined model layer

In [1]:
# Notebook configuration.
# dataset:    name of the image collection; it is embedded in the output pickle file names.
# model_name: selects the branch taken in pick_model()
#             ('resnet', 'mask2former', 'dinoS14' or 'dinoG14').
dataset = "bikelanes"
model_name = "dinoS14"

# Setup

In [2]:
from google.colab import drive
from torchvision import transforms
import torch
import time
from PIL import Image
import skimage.io as io
import pickle
import os
import numpy as np

In [3]:
# Connect to Google Drive by mounting it at /gdrive.
# The model checkpoint path and the image directory used further down are
# absolute paths below this mount point, so this cell has to run first.
drive.mount('/gdrive')

Mounted at /gdrive


# Function for picking a model and its corresponding layer

In [5]:
# function for picking a model
def pick_model(model_name):
    """Load a network and select the layer used for feature extraction.

    Parameters
    ----------
    model_name : str
        One of 'resnet', 'mask2former', 'dinoS14' or 'dinoG14'.

    Returns
    -------
    model
        The loaded network.
    my_layer
        The sub-module of ``model`` on which a forward hook should be
        registered to capture the embedding.

    Raises
    ------
    ValueError
        If ``model_name` ` is not one of the supported names.
    """

    if model_name == 'resnet':

        print('loading ResNet')

        # load model; map to CPU because the extraction loop feeds CPU tensors
        # and converts the hooked output straight to numpy
        # NOTE(review): the checkpoint appears to hold a fully pickled model
        # (its .avgpool attribute is used below). PyTorch >= 2.6 defaults to
        # weights_only=True in torch.load, which rejects such files; this may
        # need weights_only=False for this trusted checkpoint - confirm.
        model_path = '/gdrive/MyDrive/berlin_bike_CV/CobblestoneModel/finetuned_ResNet101.pt'
        model = torch.load(model_path, map_location='cpu')

        # select layer for feature extraction
        my_layer = model.avgpool

    elif model_name == 'mask2former':

        print('loading Mask2Former') # image processing for mask2former not implemented yet

        # Install transformers from source into the interpreter running this
        # kernel. This is plain Python (no notebook-only shell escape) and is
        # best-effort: a failed install does not stop execution, the import
        # below reports the problem if the package is really missing.
        import subprocess
        import sys
        subprocess.run(
            [sys.executable, '-m', 'pip', 'install', '-q',
             'git+https://github.com/huggingface/transformers.git'],
            check=False,
        )

        # load model
        from transformers import Mask2FormerForUniversalSegmentation
        model = Mask2FormerForUniversalSegmentation.from_pretrained("facebook/mask2former-swin-large-mapillary-vistas-semantic")

        # select layer for feature extraction
        my_layer = model.model.pixel_level_module.encoder.encoder.layers[3].blocks[1].output

    elif model_name == 'dinoS14':

        print('loading DinoV2 S14')

        # load model
        model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')

        # select layer for feature extraction
        my_layer = model.head

    elif model_name == 'dinoG14':

        print('loading DinoV2 G14')

        # load model
        model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

        # select layer for feature extraction
        my_layer = model.head

    else:
        # without this branch an unknown name would only fail at the return
        # statement with an unhelpful UnboundLocalError
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of "
            "'resnet', 'mask2former', 'dinoS14', 'dinoG14'"
        )

    return model, my_layer

In [6]:
# Image preprocessing applied to every picture before the forward pass:
# resize -> centre crop -> tensor conversion -> per-channel normalisation.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),  # cropped size is 224x224 (3x224x224 for an RGB image)
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)

In [7]:
# Build the network and fetch the layer whose output will be captured
model, my_layer = pick_model(model_name)

# Inference only: evaluation mode gives consistent results,
# and none of the parameters needs gradients
model.eval()
model.requires_grad_(False)

# Buffer that receives the hooked layer's output.
# `hook` looks the name `features` up at call time, and it is attached to
# my_layer in the extraction cell further down.
features = []
def hook(module, input, output):
    """Forward hook: append the hooked layer's output to `features`."""
    features.append(output)

loading DinoV2 S14


Downloading: "https://github.com/facebookresearch/dinov2/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dinov2_vits14_pretrain.pth
100%|██████████| 84.2M/84.2M [00:04<00:00, 20.1MB/s]


# Run forward pass on all images to extract embeddings

In [9]:
# get the start time
st = time.time()

# define data paths
img_dir  = '/gdrive/My Drive/berlin_bike_CV/final_project_first_images/labelled images/'
# the images live in the sub-folder named after the dataset selected at the top
# of the notebook (the same name is used for the output pickle files)
dataset_dir = os.path.join(img_dir, dataset)

# initialize empty dicts, both keyed by image file name
img_embedding  = {}  # file name -> embedding as flat numpy array
full_img_paths = {}  # file name -> full path of the image

# os.listdir returns entries in arbitrary order; sort for reproducible runs
image_files = sorted(os.listdir(dataset_dir))

# Register the hook once for the whole run instead of once per image, and make
# sure it is removed again even if an image fails to load or to pass the model.
hook_handle = my_layer.register_forward_hook(hook)
try:
    for i, image_file in enumerate(image_files):

        # progress report
        print('Image #', i+1, ' out of ',len(image_files))

        img_path = os.path.join(dataset_dir, image_file)

        # load current image; the context manager closes the file handle, and
        # convert('RGB') guarantees the 3 channels the normalisation expects
        # (grayscale or RGBA files would otherwise fail in the transform)
        with Image.open(img_path) as img:
            img_t = transform(img.convert('RGB'))

        # run forward pass
        img_unsqueezed = torch.unsqueeze(img_t, 0) # add first singleton dimension, the 'batch'
        features.clear()  # empty the hook buffer in place so features[0] belongs to this image
        with torch.no_grad():  # inference only, no autograd bookkeeping needed
            model(img_unsqueezed)

        # extract features, convert to flat np array and save with file name as key in dictionary
        img_embedding[image_file] = torch.flatten(features[0]).detach().cpu().numpy()

        # make another dictionary with full path as value
        full_img_paths[image_file] = img_path
finally:
    hook_handle.remove()

# get the end time
et = time.time()

# get the execution time
elapsed_time = et - st
print('Execution time:', elapsed_time, 'seconds')

Image # 1  out of  242
Image # 2  out of  242
Image # 3  out of  242
Image # 4  out of  242
Image # 5  out of  242
Image # 6  out of  242
Image # 7  out of  242
Image # 8  out of  242
Image # 9  out of  242
Image # 10  out of  242
Image # 11  out of  242
Image # 12  out of  242
Image # 13  out of  242
Image # 14  out of  242
Image # 15  out of  242
Image # 16  out of  242
Image # 17  out of  242
Image # 18  out of  242
Image # 19  out of  242
Image # 20  out of  242
Image # 21  out of  242
Image # 22  out of  242
Image # 23  out of  242
Image # 24  out of  242
Image # 25  out of  242
Image # 26  out of  242
Image # 27  out of  242
Image # 28  out of  242
Image # 29  out of  242
Image # 30  out of  242
Image # 31  out of  242
Image # 32  out of  242
Image # 33  out of  242
Image # 34  out of  242
Image # 35  out of  242
Image # 36  out of  242
Image # 37  out of  242
Image # 38  out of  242
Image # 39  out of  242
Image # 40  out of  242
Image # 41  out of  242
Image # 42  out of  242
I

# Save embeddings and corresponding image paths for later use

In [12]:
# save embeddings as pickle in main folder
# img_dir already ends with a path separator, so build the path with
# os.path.join rather than adding another '/' by hand
file_name = os.path.join(img_dir, model_name + '_' + dataset + "_embeddings.pickle")

# Save the dictionary (image file name -> embedding array) as a pickle file
with open(file_name, "wb") as pickle_file:
    pickle.dump(img_embedding, pickle_file)

In [13]:
# save full file paths as pickle too
# img_dir already ends with a path separator, so build the path with
# os.path.join rather than adding another '/' by hand
file_name = os.path.join(img_dir, model_name + '_' + dataset + "_paths.pickle")

# Save the dictionary (image file name -> full image path) as a pickle file
with open(file_name, "wb") as pickle_file:
    pickle.dump(full_img_paths, pickle_file)