* ViT: instead of removing a layer manually:
    * Load a backbone-only ViT model (no classification head).
    * Take the [CLS] token embedding from the final transformer layer as the image feature vector.

![image.png](attachment:b29e3ad2-15be-42d3-89b2-0be6a47703ea.png)

Step 1 – Imports

In [11]:
!pip install transformers



In [12]:
# Step 1 - Imports

import os  # work with folders and file paths
import pickle  # save and load feature dictionaries
import shutil  # remove unwanted folders

import torch  # main deep learning framework
from PIL import Image  # load images

from transformers import AutoImageProcessor, ViTModel  # ViT + preprocessor


Step 2 – Load the ViT model

In [13]:
# Step 2 - Load the pretrained ViT backbone and its image preprocessor (PyTorch)

# Checkpoint identifier shared by the model and the preprocessor.
vit_model_name = "google/vit-base-patch16-224-in21k"

# ViTModel is the backbone without a classification head.
vit_model = ViTModel.from_pretrained(vit_model_name)

# Preprocessor loaded from the same checkpoint, used to turn PIL images into model inputs.
image_processor = AutoImageProcessor.from_pretrained(vit_model_name)

# Switch to evaluation mode; as the cell's last expression this also displays the model.
vit_model.eval()


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTOutput(
          (d

STEP 3 — Show Model Structure

In [14]:
# Step 3 - Show ViT model summary

# Printing the module writes its nested layer structure (embeddings, encoder
# layers, ...) to the cell output as plain text.
print(vit_model)  # print architecture details


ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTOutput(
          (d

STEP 4 — Prepare Feature Storage

In [15]:
# Step 4 - Set up the feature store and point at the image folder

directory = r"test - front masked images - side"  # folder scanned for input images
features = dict()  # empty store; the extraction step adds one entry per image


STEP 5 — Remove .ipynb_checkpoints

In [16]:
# Step 5 - Remove notebook checkpoint folders

# Reuse the image folder configured in Step 4 instead of repeating the path
# literal, so this cleanup always targets the same folder as the extraction.
folder = directory

# Check the one path of interest directly rather than scanning every entry
# of the folder for a matching name.
path = os.path.join(folder, ".ipynb_checkpoints")
if os.path.isdir(path):
    shutil.rmtree(path)  # remove unwanted folder
    print("Removed:", path)


STEP 6 — Extract Features With ViT

In [17]:
# Step 6 - Extract features using ViT (PyTorch)
#
# For every image in `directory`, run the ViT backbone and store the CLS-token
# embedding of the final layer in `features`, keyed by the image's path.

valid_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".gif"}  # allowed image types

file_list = os.listdir(directory)  # list folder items
print(f"Found {len(file_list)} files in directory: {directory}")

# Select the images up front (regular files with an allowed extension) so the
# progress counter below counts images actually processed; counting raw folder
# entries would leave gaps in the index and overstate the total whenever the
# folder holds sub-folders or non-image files.
image_names = [
    name
    for name in file_list
    if os.path.isfile(os.path.join(directory, name))
    and os.path.splitext(name)[1].lower() in valid_extensions
]

for idx, image_name in enumerate(image_names, start=1):
    image_path = os.path.join(directory, image_name)

    # Open inside a context manager so the underlying file handle is released
    # right after decoding. convert() returns a separate in-memory image, so
    # `img` stays usable once the file is closed.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")

    # preprocess image for ViT (returns PyTorch tensors)
    inputs = image_processor(images=img, return_tensors="pt")

    # forward pass without gradient tracking; position 0 of the sequence is the CLS token
    with torch.no_grad():
        outputs = vit_model(**inputs)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # CLS token

    feature_vector = cls_embedding.cpu().numpy()  # convert tensor -> numpy

    # store in dictionary, keyed by the image path
    features[image_path] = feature_vector

    print(f"Processed {idx}/{len(image_names)}: {image_name}")


Found 1684 files in directory: test - front masked images - side
Processed 1/1684: e5ae8fe5bbdf611a1e8d06e66e849bdf.png
Processed 2/1684: 605a5fd09058c48156b0ef518b63b2de.png
Processed 3/1684: 909c9277309e13ee014e347603aba620.png
Processed 4/1684: bef6a68bc8dd475c124f6de2413385d3.png
Processed 5/1684: 6d7ed4bc4a17546447efed0ca6e2ff11.png
Processed 6/1684: 4c12d6a82cb0a75ee556a54ab1afc21e.png
Processed 7/1684: dd27be8b3d6b9c2c14c14318612ba0dc.png
Processed 8/1684: fc65b84d8183e3a872785b4e2eecaa66.png
Processed 9/1684: 851f712c7cfc6b62b20b6f8cba65c20a.png
Processed 10/1684: c9ede0a19f8e79ec7a4cd8f126129f2d.png
Processed 11/1684: 652500aa90597ed06ccc8f15bc9b83aa.png
Processed 12/1684: b55e0664c7c1642cd015b5585a8d5fd3.png
Processed 13/1684: 67194d8a47331d1f722db4f737546021.png
Processed 14/1684: 719fc866000f5edb56700ee0755ea109.png
Processed 15/1684: 46c5d36f10af6cf0fd2b4a02eb9a9add.png
Processed 16/1684: e1a3c3b6df492a58fde15be36a8371bc.png
Processed 17/1684: 7ac05db49a083076cb5ed077f16b9

STEP 7 — Convert Keys → Filenames Only + Save Pickle

In [18]:
# Step 7 - Re-key features by bare filename and write them to a pickle file

# Drop the directory part of every key so entries are addressed by filename alone.
features_renamed = {
    os.path.basename(full_path): vec for full_path, vec in features.items()
}
features = features_renamed  # from here on, `features` is keyed by filename

# Show the first five keys as a quick sanity check.
print("Sample keys:")
for sample_key in list(features)[:5]:
    print(sample_key)

# Serialize the whole dictionary to disk.
pickle_path = "vit-base_side_masked_features-testA.pkl"  # output file
with open(pickle_path, "wb") as out_file:
    pickle.dump(features, out_file)

print(f"Saved {len(features)} feature vectors to {pickle_path}")


Sample keys:
e5ae8fe5bbdf611a1e8d06e66e849bdf.png
605a5fd09058c48156b0ef518b63b2de.png
909c9277309e13ee014e347603aba620.png
bef6a68bc8dd475c124f6de2413385d3.png
6d7ed4bc4a17546447efed0ca6e2ff11.png
Saved 1684 feature vectors to vit-base_side_masked_features-testA.pkl


In [19]:
# Inspect the shape of `feature_vector`, i.e. the value left in scope by the
# last iteration of the Step 6 extraction loop.
feature_vector.shape

(1, 768)

STEP 8 — Compare Folder Files vs Feature Dictionary Keys

In [20]:
# Step 8 - Cross-check image files on disk against the feature dictionary keys

# Collect the name of every regular file in the folder with an allowed extension.
image_files_in_dir = []
for name in os.listdir(directory):
    full_path = os.path.join(directory, name)
    ext = os.path.splitext(name)[1].lower()
    if not os.path.isfile(full_path) or ext not in valid_extensions:
        continue
    image_files_in_dir.append(name)

files_set = set(image_files_in_dir)
keys_set = set(features.keys())

# Set differences in each direction expose anything present on only one side.
keys_not_in_folder = keys_set - files_set
files_not_in_dict = files_set - keys_set

print(f"Number of image files: {len(files_set)}")
print(f"Number of features saved: {len(keys_set)}\n")

if files_not_in_dict or keys_not_in_folder:
    print("⚠ Mismatches found:\n")

    if files_not_in_dict:
        print(f"Files in folder but not in dictionary ({len(files_not_in_dict)}):")
        for name in sorted(files_not_in_dict)[:10]:  # list at most ten examples
            print("  -", name)
        print()

    if keys_not_in_folder:
        print(f"Keys in dictionary without matching files ({len(keys_not_in_folder)}):")
        for name in sorted(keys_not_in_folder)[:10]:  # list at most ten examples
            print("  -", name)
else:
    print("✅ All images match dictionary keys.")


Number of image files: 1684
Number of features saved: 1684

✅ All images match dictionary keys.
