In [11]:
pip install transformers

Note: you may need to restart the kernel to use updated packages.


Step 1 – Imports

In [12]:
# Step 1 - Imports

import os                      # work with file paths and directories
import pickle                  # save and load Python dictionaries
import shutil                  # remove folders like .ipynb_checkpoints

import torch                   # main deep learning framework
from PIL import Image          # load and handle images

from transformers import AutoImageProcessor, ViTMAEModel  # MAE image processor and model


Step 2 — Load MAE ViT-H (Huge) model

In [13]:
# Step 2 - Load MAE ViT-H (Huge) model

mae_model_name = "facebook/vit-mae-huge"                        # model name for MAE ViT-H

image_processor = AutoImageProcessor.from_pretrained(mae_model_name)  # create preprocessor
mae_model = ViTMAEModel.from_pretrained(mae_model_name)               # load pretrained MAE model
mae_model.eval()                                                       # set to evaluation mode


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


ViTMAEModel(
  (embeddings): ViTMAEEmbeddings(
    (patch_embeddings): ViTMAEPatchEmbeddings(
      (projection): Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14))
    )
  )
  (encoder): ViTMAEEncoder(
    (layer): ModuleList(
      (0-31): 32 x ViTMAELayer(
        (attention): ViTMAEAttention(
          (attention): ViTMAESelfAttention(
            (query): Linear(in_features=1280, out_features=1280, bias=True)
            (key): Linear(in_features=1280, out_features=1280, bias=True)
            (value): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (output): ViTMAESelfOutput(
            (dense): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTMAEIntermediate(
          (dense): Linear(in_features=1280, out_features=5120, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTMAEOutput(
          (dense)

Step 3 – Show model summary

In [14]:
# Step 3 - Show MAE ViT-Large model summary

print(mae_model)   # print architecture and transformer layers


ViTMAEModel(
  (embeddings): ViTMAEEmbeddings(
    (patch_embeddings): ViTMAEPatchEmbeddings(
      (projection): Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14))
    )
  )
  (encoder): ViTMAEEncoder(
    (layer): ModuleList(
      (0-31): 32 x ViTMAELayer(
        (attention): ViTMAEAttention(
          (attention): ViTMAESelfAttention(
            (query): Linear(in_features=1280, out_features=1280, bias=True)
            (key): Linear(in_features=1280, out_features=1280, bias=True)
            (value): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (output): ViTMAESelfOutput(
            (dense): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTMAEIntermediate(
          (dense): Linear(in_features=1280, out_features=5120, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTMAEOutput(
          (dense)

Step 4 – Prepare feature storage and image directory

In [15]:
# Step 4 - Prepare feature storage and image directory

features = {}                                              # dictionary to store extracted feature vectors
directory = r"test - front masked images - side"                          # folder containing your input images


Step 5 – Remove .ipynb_checkpoints folder

In [16]:
# Step 5 - Remove any '.ipynb_checkpoints' folder inside the image folder

folder = "test - front masked images - side"                               # folder to clean

for item in os.listdir(folder):                              # loop through folder contents
    path = os.path.join(folder, item)                        # full path to item
    if item == ".ipynb_checkpoints" and os.path.isdir(path): # check for checkpoint folder
        shutil.rmtree(path)                                  # delete the folder
        print("Removed:", path)                              # confirm deletion


Step 6 — Extract features with MAE ViT-H

In [17]:
# Step 6 - Extract features using MAE ViT-H (Huge)

valid_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".gif"}  # valid image extensions

file_list = os.listdir(directory)                             # list files
print(f"Found {len(file_list)} items in directory: {directory}")

for idx, image_name in enumerate(file_list, start=1):          # loop with index
    image_path = os.path.join(directory, image_name)           # full image path

    if not os.path.isfile(image_path):                         # skip non-files
        continue

    ext = os.path.splitext(image_name)[1].lower()              # file extension
    if ext not in valid_extensions:                            # skip invalid
        continue

    img = Image.open(image_path).convert("RGB")                # load image as RGB

    inputs = image_processor(images=img, return_tensors="pt")  # preprocess

    with torch.no_grad():                                      # no gradients needed
        outputs = mae_model(**inputs)                          # forward pass
        encoder_out = outputs.last_hidden_state                # shape: [1, tokens, dim]
        feature_vector = encoder_out.mean(dim=1)               # mean pool → [1, dim]

    feature_vector = feature_vector.cpu().numpy()              # convert to NumPy

    features[image_path] = feature_vector                      # save into dictionary

    print(f"Processed {idx}/{len(file_list)}: {image_name}")   # progress print

feature_vector.shape                                            # show vector dimension


Found 1684 items in directory: test - front masked images - side
Processed 1/1684: e5ae8fe5bbdf611a1e8d06e66e849bdf.png
Processed 2/1684: 605a5fd09058c48156b0ef518b63b2de.png
Processed 3/1684: 909c9277309e13ee014e347603aba620.png
Processed 4/1684: bef6a68bc8dd475c124f6de2413385d3.png
Processed 5/1684: 6d7ed4bc4a17546447efed0ca6e2ff11.png
Processed 6/1684: 4c12d6a82cb0a75ee556a54ab1afc21e.png
Processed 7/1684: dd27be8b3d6b9c2c14c14318612ba0dc.png
Processed 8/1684: fc65b84d8183e3a872785b4e2eecaa66.png
Processed 9/1684: 851f712c7cfc6b62b20b6f8cba65c20a.png
Processed 10/1684: c9ede0a19f8e79ec7a4cd8f126129f2d.png
Processed 11/1684: 652500aa90597ed06ccc8f15bc9b83aa.png
Processed 12/1684: b55e0664c7c1642cd015b5585a8d5fd3.png
Processed 13/1684: 67194d8a47331d1f722db4f737546021.png
Processed 14/1684: 719fc866000f5edb56700ee0755ea109.png
Processed 15/1684: 46c5d36f10af6cf0fd2b4a02eb9a9add.png
Processed 16/1684: e1a3c3b6df492a58fde15be36a8371bc.png
Processed 17/1684: 7ac05db49a083076cb5ed077f16b9

(1, 1280)

In [18]:
feature_vector.shape

(1, 1280)

Step 7 – Rename keys to filenames only and save as a pickle file

In [19]:
# Step 7 - Rename keys to filenames only and save the features dictionary to a pickle file

features_renamed = {}                                          # new dictionary

for full_path, vec in features.items():                        # loop through original features
    filename = os.path.basename(full_path)                     # extract filename
    features_renamed[filename] = vec                           # store under filename

features = features_renamed                                    # replace dictionary

print("Sample keys after rename:")                              # show sample keys
for i, k in enumerate(features.keys()):
    print(k)
    if i == 4:                                                 # print only first 5
        break

pickle_path = "mae_large_side_masked_features-testA.pkl"             # output filename

with open(pickle_path, "wb") as f:                             # open file for writing
    pickle.dump(features, f)                                   # save dictionary

print(f"Saved {len(features)} feature vectors to {pickle_path}")


Sample keys after rename:
e5ae8fe5bbdf611a1e8d06e66e849bdf.png
605a5fd09058c48156b0ef518b63b2de.png
909c9277309e13ee014e347603aba620.png
bef6a68bc8dd475c124f6de2413385d3.png
6d7ed4bc4a17546447efed0ca6e2ff11.png
Saved 1684 feature vectors to mae_large_side_masked_features-testA.pkl


Step 8 – Compare image filenames and dictionary keys

In [20]:
# Step 8 - Compare actual image filenames with keys in the features dictionary

image_files_in_dir = []                                         # list of valid images

for name in os.listdir(directory):                              # loop through folder
    full_path = os.path.join(directory, name)                   # full path
    ext = os.path.splitext(name)[1].lower()                     # file extension
    if os.path.isfile(full_path) and ext in valid_extensions:   # keep only images
        image_files_in_dir.append(name)

files_set = set(image_files_in_dir)                             # convert to set
keys_set = set(features.keys())                                 # dictionary keys

files_not_in_dict = files_set - keys_set                        # missing feature entries
keys_not_in_folder = keys_set - files_set                       # extra keys not matching folder

print(f"Number of image files in directory: {len(files_set)}")
print(f"Number of keys in features dictionary: {len(keys_set)}\n")

if not files_not_in_dict and not keys_not_in_folder:            # perfect match
    print("✅ All image filenames and dictionary keys MATCH exactly.")
else:
    print("⚠ Some mismatches were found:\n")

    if files_not_in_dict:
        print(f"Files in directory but NOT in dictionary ({len(files_not_in_dict)}):")
        for i, name in enumerate(sorted(files_not_in_dict)):
            print("  -", name)
            if i == 9:
                break
        print()

    if keys_not_in_folder:
        print(f"Keys in dictionary but NO corresponding file in directory ({len(keys_not_in_folder)}):")
        for i, name in enumerate(sorted(keys_not_in_folder)):
            print("  -", name)
            if i == 9:
                break


Number of image files in directory: 1684
Number of keys in features dictionary: 1684

✅ All image filenames and dictionary keys MATCH exactly.
