In [None]:
pip install transformers

Step 1 – Imports

In [None]:
# Step 1 - Imports

import os                      # work with file paths and directories
import pickle                  # save and load Python dictionaries
import shutil                  # remove folders like .ipynb_checkpoints

import torch                   # main deep learning framework
from PIL import Image          # load and handle images

from transformers import AutoImageProcessor, ViTMAEModel  # MAE image processor and model


Step 2 – Load MAE ViT-H (Huge) model

In [None]:
# Step 2 - Load MAE ViT-H (Huge) model

mae_model_name = "facebook/vit-mae-huge"                        # model name for MAE ViT-H

image_processor = AutoImageProcessor.from_pretrained(mae_model_name)  # create preprocessor

# ViTMAEModel randomly masks patches on every forward pass (config.mask_ratio,
# 0.75 by default), and eval() does not switch that off. For feature extraction
# this would encode only a random quarter of each image and yield a different
# vector on every run, so masking is disabled and the encoder sees every patch.
mae_model = ViTMAEModel.from_pretrained(mae_model_name, mask_ratio=0.0)  # load pretrained MAE model
mae_model.eval()                                                         # set to evaluation mode


Step 3 – Show model summary

In [None]:
# Step 3 - Show MAE ViT-H (Huge) model summary

print(mae_model)   # print the module tree of the loaded model (its plain-text repr)


Step 4 – Prepare feature storage and image directory

In [None]:
# Step 4 - Set up the input folder and the container for extracted features

directory = "side masked images"    # folder containing the input images
features = dict()                   # empty dictionary that will hold the feature vectors


Step 5 – Remove .ipynb_checkpoints folder

In [None]:
# Step 5 - Delete a '.ipynb_checkpoints' sub-folder from the image folder, if present

folder = "side masked images"                 # folder to clean
checkpoint_name = ".ipynb_checkpoints"        # name of the entry to look for

# Listing the folder first means a missing image folder still raises here.
if checkpoint_name in os.listdir(folder):
    path = os.path.join(folder, checkpoint_name)
    if os.path.isdir(path):                   # only remove it when it is really a directory
        shutil.rmtree(path)                   # delete the folder and everything inside it
        print("Removed:", path)               # confirm deletion


Step 6 – Extract features with MAE ViT-H

In [None]:
# Step 6 - Extract features using MAE ViT-H (Huge)
#
# For every image file in `directory`, run the MAE encoder and store the
# mean-pooled last hidden state in `features`, keyed by the image path.

valid_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".gif"}  # valid image extensions

file_list = os.listdir(directory)                             # list files
print(f"Found {len(file_list)} items in directory: {directory}")

# Keep only regular files with an image extension. Filtering up front makes the
# progress counter reflect real work, and sorting makes the processing order
# (and therefore the dictionary order) the same on every run.
image_names = sorted(
    name
    for name in file_list
    if os.path.isfile(os.path.join(directory, name))
    and os.path.splitext(name)[1].lower() in valid_extensions
)
print(f"{len(image_names)} of them are image files")

# Start from an empty dictionary so that re-running this cell cannot leave
# stale entries (or keys in a different format) from a previous run.
features.clear()

for idx, image_name in enumerate(image_names, start=1):        # loop with index
    image_path = os.path.join(directory, image_name)           # full image path

    try:
        # The context manager closes the underlying file handle right away;
        # convert() returns an independent in-memory copy.
        with Image.open(image_path) as raw_image:
            img = raw_image.convert("RGB")                     # load image as RGB
    except OSError as err:
        # Unreadable or corrupt file: report it and continue with the rest
        # instead of aborting the whole extraction run.
        print(f"Skipped {idx}/{len(image_names)}: {image_name} ({err})")
        continue

    inputs = image_processor(images=img, return_tensors="pt")  # preprocess

    with torch.no_grad():                                      # no gradients needed
        outputs = mae_model(**inputs)                          # forward pass
        encoder_out = outputs.last_hidden_state                # shape: [1, tokens, dim]
        feature_vector = encoder_out.mean(dim=1)               # mean pool → [1, dim]

    feature_vector = feature_vector.cpu().numpy()              # convert to NumPy

    features[image_path] = feature_vector                      # save into dictionary

    print(f"Processed {idx}/{len(image_names)}: {image_name}") # progress print

# Show the vector dimension; evaluates to None (no output) when nothing was
# processed, instead of failing on an undefined `feature_vector`.
feature_vector.shape if features else None


In [None]:
# Display the shape of the most recently extracted feature vector
feature_vector.shape

Step 7 – Rename keys to filenames only and save as a pickle file

In [None]:
# Step 7 - Rename keys to filenames only and save the features dictionary to a pickle file

# Re-key every entry by its bare filename (directory part removed).
features_renamed = {
    os.path.basename(full_path): vec
    for full_path, vec in features.items()
}

features = features_renamed                                    # replace dictionary

print("Sample keys after rename:")                              # show sample keys
for key in list(features)[:5]:                                 # print only first 5
    print(key)

# The vectors in this notebook come from the ViT-H (Huge) checkpoint, so the
# file is named accordingly. The previous name said "large", which mislabels
# the contents and could overwrite a feature file produced with a ViT-Large
# model. Update any downstream loader that still expects the old name.
pickle_path = "mae_huge_side_masked_features.pkl"              # output filename

with open(pickle_path, "wb") as f:                             # open file for writing
    pickle.dump(features, f)                                   # save dictionary

print(f"Saved {len(features)} feature vectors to {pickle_path}")


Step 8 – Compare image filenames and dictionary keys

In [None]:
# Step 8 - Check that the image files on disk and the keys of `features` line up

# Names of all regular files in the folder that carry a valid image extension.
image_files_in_dir = [
    name
    for name in os.listdir(directory)
    if os.path.isfile(os.path.join(directory, name))
    and os.path.splitext(name)[1].lower() in valid_extensions
]

files_set = set(image_files_in_dir)                             # filenames found on disk
keys_set = set(features.keys())                                 # keys stored in the dictionary

files_not_in_dict = files_set.difference(keys_set)              # images without a feature entry
keys_not_in_folder = keys_set.difference(files_set)             # entries without an image file

print(f"Number of image files in directory: {len(files_set)}")
print(f"Number of keys in features dictionary: {len(keys_set)}\n")

if files_not_in_dict or keys_not_in_folder:
    print("⚠ Some mismatches were found:\n")

    # Each listing is capped at the first ten names in sorted order.
    if files_not_in_dict:
        print(f"Files in directory but NOT in dictionary ({len(files_not_in_dict)}):")
        for name in sorted(files_not_in_dict)[:10]:
            print("  -", name)
        print()

    if keys_not_in_folder:
        print(f"Keys in dictionary but NO corresponding file in directory ({len(keys_not_in_folder)}):")
        for name in sorted(keys_not_in_folder)[:10]:
            print("  -", name)
else:
    # Both difference sets are empty: the two collections are identical.
    print("✅ All image filenames and dictionary keys MATCH exactly.")
