In [None]:
%pip install pandas scikit-learn Pillow numpy tqdm rembg matplotlib onnxruntime transformers torch torchvision torchaudio
%pip install kornia timm # RMBG-2.0

In [None]:
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
from tqdm import tqdm
import pandas as pd
from glob import glob
import PIL
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

In [None]:
# Where the reference catalogue (DAM) and the query photos live
dam_dir = 'data/DAM'
test_dir = 'data/test_image_headmind'
extensions = ['*.jpg', '*.jpeg', '*.png']

# Only '*.jpeg' files are collected for DAM; test images may match any of `extensions`
dam_images = sorted(glob(os.path.join(dam_dir, '*.jpeg')))
test_images = sorted(
    path
    for pattern in (os.path.join(test_dir, ext) for ext in extensions)
    for path in glob(pattern)
)

# One single-column frame per image set
dam_df = pd.DataFrame({'image_path': dam_images})
test_df = pd.DataFrame({'image_path': test_images})

for heading, frame in (("DAM DataFrame:", dam_df), ("\nTest DataFrame:", test_df)):
    print(heading)
    print(frame.head())


In [None]:
import importlib

import matplotlib.pyplot as plt
from PIL import Image, ImageOps

import utils.preprocessing

# Reload the module first, then re-bind the name: a plain `from ... import` keeps
# pointing at the pre-reload function.
importlib.reload(utils.preprocessing)
from utils.preprocessing import preprocess_image

BASE_CACHE_DIR = '.cache'

# Prefer CUDA, then Apple-silicon MPS, otherwise CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Row of test_df used for this visual sanity check of background removal
SAMPLE_INDEX = 60
sample_test_path = test_df['image_path'].iloc[SAMPLE_INDEX]
print(f"Sample test path: {sample_test_path}")
extracted_object = preprocess_image(sample_test_path, background_removal="RMBG_2")

if extracted_object is None:
    # Nothing to plot on the right-hand side; imshow(None) would raise
    print("No object found in the image at path:", sample_test_path)
else:
    # Display the original (EXIF-rotated) and the extracted object side by side
    fig, (ax_orig, ax_ext) = plt.subplots(1, 2, figsize=(10, 5))
    with Image.open(sample_test_path) as original:
        ax_orig.imshow(ImageOps.exif_transpose(original))
    ax_orig.set_title("Original Image")
    ax_orig.axis('off')

    ax_ext.imshow(extracted_object)
    ax_ext.set_title("Extracted Object")
    ax_ext.axis('off')
    plt.show()

In [None]:
import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm
import os
import tensorflow as tf
from tensorflow.keras.preprocessing import image as keras_image
from sklearn.metrics.pairwise import cosine_similarity
import importlib
import json

# Reload the model modules so edits under utils/models are picked up without a
# kernel restart. Order matters:
#   1. the base class first;
#   2. then the concrete model modules, so they subclass the fresh base;
#   3. the package last, so the names it re-exports (the import below relies on
#      utils/models/__init__ exposing the classes) point at the reloaded classes.
import utils.models
import utils.models.base_model
import utils.models.dinov2_model
import utils.models.facebook_vitmsn_model
import utils.models.google_vit_model
import utils.models.microsoft_resnet_model
import utils.models.openai_clip_model

for _module in (
    utils.models.base_model,
    utils.models.dinov2_model,
    utils.models.facebook_vitmsn_model,
    utils.models.google_vit_model,
    utils.models.microsoft_resnet_model,
    utils.models.openai_clip_model,
    utils.models,
):
    importlib.reload(_module)

from utils.models import DinoV2Model, FacebookViTMSNModel, GoogleViTModel, MicrosoftResNetModel, OpenAIClipModel


BACKGROUND_REMOVAL_METHOD = 'RMBG_2' # 'rembg', 'RMBG_2'
AUGMENT_WITH_3D_MODEL = False

SELECTED_MODEL = GoogleViTModel
EMBEDDING_AGGREGRATION_METHOD = "mean" # None, 'mean'

# NOTE(review): tensorflow is already imported in this cell, so this probably does
# not silence its import-time logging; it must be set before `import tensorflow`
# to be fully effective.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow warnings

# Background removal used for the TRELLIS multi-view renders; it is fixed
# independently of BACKGROUND_REMOVAL_METHOD and is part of the 3D cache filename.
BACKGROUND_REMOVAL_METHOD_3D = 'rembg'
# Renders are looked up as <dam_id>-1.jpeg ... <dam_id>-N_3D_VIEWS.jpeg
N_3D_VIEWS = 8

# Select the desired model by key
selected_model_key = SELECTED_MODEL.__name__
model = SELECTED_MODEL()


def _load_or_extract_features(cache_file, image_paths, removal_method, label):
    """Return a {image_path: features} dict, loaded from `cache_file` when it exists.

    Otherwise each path in `image_paths` is preprocessed with `removal_method`.
    Features are extracted with the global `model`; images where preprocessing
    returns None are skipped. The dict is then pickled to `cache_file`.

    NOTE: the cache is keyed only by model and background-removal method (via the
    filename), so delete it after changing the DAM image set. pickle.load can run
    arbitrary code -- only load cache files this notebook wrote itself.
    """
    if os.path.exists(cache_file):
        with open(cache_file, 'rb') as f:
            features = pickle.load(f)
        print(f"Loaded precomputed {label} from {cache_file}.")
        return features

    features = {}
    for img_path in tqdm(image_paths, desc=f"Extracting {label}"):
        img = preprocess_image(img_path, removal_method)
        if img is None:
            # Skip if no object found
            continue
        features[img_path] = model.extract_features(img)

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'wb') as f:
        pickle.dump(features, f)
    print(f"Extracted {label} and saved to {cache_file}.")
    return features


def _trellis_view_paths():
    """Yield the TRELLIS render paths that exist on disk, for every DAM image."""
    for img_path in dam_df['image_path']:
        dam_id = os.path.basename(img_path).split('.')[0]
        for view in range(1, N_3D_VIEWS + 1):
            view_path = f"{BASE_CACHE_DIR}/TRELLIS/{dam_id}/{dam_id}-{view}.jpeg"
            if os.path.exists(view_path):
                yield view_path


embeddings_file = os.path.join(BASE_CACHE_DIR, "embeddings", f'dam_features-{selected_model_key}-{BACKGROUND_REMOVAL_METHOD}.pkl')
dam_features = _load_or_extract_features(
    embeddings_file, dam_df['image_path'].tolist(), BACKGROUND_REMOVAL_METHOD, "DAM features"
)

# Augment the references with features of the 3D-model renders
if AUGMENT_WITH_3D_MODEL:
    embeddings_file_3d = os.path.join(BASE_CACHE_DIR, "embeddings", f'dam_features-{selected_model_key}-{BACKGROUND_REMOVAL_METHOD_3D}-3d.pkl')
    # The generator is only consumed when the 3D cache has to be (re)built
    dam_features_3d = _load_or_extract_features(
        embeddings_file_3d, _trellis_view_paths(), BACKGROUND_REMOVAL_METHOD_3D, "DAM features (3D)"
    )
    # Each rendered view becomes an extra reference entry keyed by its own path
    dam_features = dam_features | dam_features_3d

In [None]:
import pandas as pd

# Load the hand-made ground truth. Read everything as text so numeric reference
# codes are not turned into ints/floats (e.g. '123' -> 123.0) that never match
# the DAM filenames.
labels_df = pd.read_csv('labels/handmade_test_labels.csv', dtype=str)


def parse_references(raw_reference):
    """Split a '/'-separated 'reference' cell into a list of reference codes.

    Surrounding spaces are stripped, and empty entries and '?' placeholders are
    dropped. A missing cell (NaN) yields [] rather than the literal code 'nan'.
    """
    if pd.isna(raw_reference):
        return []
    codes = (part.strip() for part in str(raw_reference).split('/'))
    return [code for code in codes if code and code != '?']


# Map each test image filename to its list of acceptable reference codes.
# Rows without an image name are skipped; for duplicate names the last row wins.
labels_dict = {
    str(image_name).strip(): parse_references(reference)
    for image_name, reference in zip(labels_df['image'], labels_df['reference'])
    if not pd.isna(image_name)
}

# For debugging, you can print a sample of the dictionary
print(dict(list(labels_dict.items())[:5]))

In [None]:
# Turn a model's raw feature output into the 2-D (rows, features) matrix that
# sklearn's cosine_similarity expects.
def aggregate_embedding(embedding, method=...):
    """Return `embedding` as a 2-D array of feature rows.

    Inputs with at most 2 dimensions are flattened into a single row.

    Higher-rank inputs are handled as follows:
      * Singleton axes are squeezed away.
      * If only a single vector remains (e.g. a (1, 1, D) output), it becomes one
        row, instead of being averaged down to a scalar.
      * Otherwise, the last axis is treated as the feature axis and all leading
        axes as tokens/patches. (Assumed per model output -- TODO confirm.)
      * With method == 'mean' the token rows are averaged into one row; with any
        other value (e.g. None) the token rows are returned unchanged.

    Args:
        embedding: array-like output of `model.extract_features`.
        method: aggregation method; when omitted, the notebook-level
            EMBEDDING_AGGREGRATION_METHOD setting is used.

    Returns:
        A 2-D numpy array.
    """
    embedding = np.asarray(embedding)
    if embedding.ndim <= 2:
        return embedding.reshape(1, -1)

    if method is ...:
        method = EMBEDDING_AGGREGRATION_METHOD

    tokens = embedding.squeeze()
    if tokens.ndim <= 1:
        # Squeezing removed every axis except the feature vector itself
        return tokens.reshape(1, -1)

    # Fold any remaining leading axes into a single token axis
    tokens = tokens.reshape(-1, tokens.shape[-1])
    if method == 'mean':
        return tokens.mean(axis=0).reshape(1, -1)
    return tokens

In [None]:
import math
import time

from PIL import ImageOps

# Number of ranked DAM matches scored and displayed per query
top_n = 10
# Height in inches of each row of the per-query result figure
max_height = 5


def _rank_matches(query_feat, reference_feats):
    """Return [(dam_path, similarity), ...] sorted by decreasing similarity.

    The similarity is the mean of the pairwise cosine similarities between the
    query rows and the DAM rows. With 'mean' aggregation both sides are a single
    row, so the pairwise matrix is (1, 1).
    """
    scored = []
    for dam_path, dam_feat in reference_feats.items():
        pairwise_sim = cosine_similarity(query_feat, dam_feat)
        # Aggregate similarity (mean; max or sum are alternatives)
        scored.append((dam_path, pairwise_sim.mean()))
    # Sort once, after every DAM image has been scored
    return sorted(scored, key=lambda match: match[1], reverse=True)


def _reference_code(match_path):
    """Return the DAM code of a match: its filename stem, cut at the first '-'.

    The cut maps 3D-render views (<dam_id>-<view>) back to their DAM id.
    """
    stem = os.path.splitext(os.path.basename(match_path))[0]
    return stem.split("-")[0]


def _first_correct_rank(ranked, true_references, n):
    """Zero-based rank of the first of the top `n` matches whose code is a true reference, else None."""
    for rank, (match_path, _) in enumerate(ranked[:n]):
        if _reference_code(match_path) in true_references:
            return rank
    return None


def _plot_query(test_path, test_obj, ranked):
    """Show the query photo, its extracted object and its top `top_n` DAM matches."""
    # At least 2 columns (original + extracted), at most 5 matches per row
    ncols = min(max(2, top_n), 5)
    # One header row plus enough rows for the matches
    nrows = math.ceil(top_n / ncols) + 1

    fig = plt.figure(constrained_layout=True, figsize=(4 * ncols, max_height * nrows))
    gs = fig.add_gridspec(nrows, ncols)

    ax_orig = fig.add_subplot(gs[0, 0])
    with Image.open(test_path) as original:
        ax_orig.imshow(ImageOps.exif_transpose(original))
    ax_orig.set_title("Original Image")
    ax_orig.axis('off')

    ax_ext = fig.add_subplot(gs[0, 1])
    ax_ext.imshow(test_obj)
    ax_ext.set_title("Extracted Object")
    ax_ext.axis('off')

    # Leave remaining cells in the first row empty
    for col in range(2, ncols):
        fig.add_subplot(gs[0, col]).axis('off')

    for i, (match_path, sim_score) in enumerate(ranked[:top_n]):
        ax = fig.add_subplot(gs[i // ncols + 1, i % ncols])
        with Image.open(match_path) as match_img:
            ax.imshow(ImageOps.exif_transpose(match_img))
        ax.set_title(f"Match {i+1}\nSim: {sim_score:.4f}")
        ax.axis('off')

    plt.show()
    # One figure is created per query; release it explicitly
    plt.close(fig)


# Accuracy counters over all queries
total_queries = 0
correct_top1 = 0
correct_top3 = 0
correct_top5 = 0

# Zero-based rank of the first correct match per test image name (only when found)
correct_match_indices = {}

results = []

# DAM embeddings do not depend on the query, so aggregate them once up front
aggregated_dam_features = {
    dam_path: aggregate_embedding(dam_feat) for dam_path, dam_feat in dam_features.items()
}

for test_path in tqdm(test_df['image_path'], total=len(test_df), desc="Processing test images"):
    test_image_name = os.path.basename(test_path)
    total_queries += 1

    t = time.time()
    # Same background removal as the cached DAM features, so both sides are comparable
    test_obj = preprocess_image(test_path, BACKGROUND_REMOVAL_METHOD)
    print(f"Preprocessing time: {time.time() - t:.4f} seconds")

    if test_obj is None:
        # Nothing to match: count the query as a miss instead of crashing below
        tqdm.write(f"No object found in {test_path}; counted as a miss.")
        continue

    t = time.time()
    test_feat = model.extract_features(test_obj)
    print(f"Feature extraction time: {time.time() - t:.4f} seconds")
    test_feat = aggregate_embedding(test_feat)

    sorted_matches = _rank_matches(test_feat, aggregated_dam_features)
    results.append({
        'test_image': test_path,
        'top_matches': sorted_matches
    })

    # Smallest rank among the top_n matches whose code is a true reference
    true_references = labels_dict.get(test_image_name, [])
    correct_index = _first_correct_rank(sorted_matches, true_references, top_n)
    if correct_index is not None:
        correct_match_indices[test_image_name] = correct_index
        if correct_index < 1:
            correct_top1 += 1
        if correct_index < 3:
            correct_top3 += 1
        if correct_index < 5:
            correct_top5 += 1

    _plot_query(test_path, test_obj, sorted_matches)

    if correct_index is not None:
        print(f"Correct match found at top-{correct_index + 1}")

    tqdm.write(f"Processed {test_path} -> Top {top_n} matches:")
    for match_path, sim_score in sorted_matches[:top_n]:
        tqdm.write(f"\tMatch: {match_path}, Similarity: {sim_score:.4f}")

# After processing all test images, compute and display overall accuracy metrics
if total_queries > 0:
    accuracy_top1 = correct_top1 / total_queries
    accuracy_top3 = correct_top3 / total_queries
    accuracy_top5 = correct_top5 / total_queries

    print(f"Total queries processed: {total_queries}")
    print(f"Top-1 Accuracy: {accuracy_top1:.2%} ({correct_top1} correct)")
    print(f"Top-3 Accuracy: {accuracy_top3:.2%} ({correct_top3} correct)")
    print(f"Top-5 Accuracy: {accuracy_top5:.2%} ({correct_top5} correct)")
else:
    # Nothing was scored; written to the benchmark file as null
    accuracy_top1 = accuracy_top3 = accuracy_top5 = None
    print("No queries were processed.")

# Benchmark filename encodes the configuration that produced these numbers
filename = f"{type(model).__name__}-{BACKGROUND_REMOVAL_METHOD}-{EMBEDDING_AGGREGRATION_METHOD}"
if AUGMENT_WITH_3D_MODEL:
    filename += "_3d"
filename += ".json"

benchmark = {
    "top_1_accuracy": accuracy_top1,
    "top_3_accuracy": accuracy_top3,
    "top_5_accuracy": accuracy_top5,
}

os.makedirs("benchmarks", exist_ok=True)
with open(os.path.join("benchmarks", filename), 'w') as f:
    json.dump(benchmark, f)