In [None]:
%pip install pandas scikit-learn Pillow numpy tqdm rembg matplotlib opencv-python onnxruntime transformers torch torchvision torchaudio

In [None]:
import tqdm

In [None]:
import os
import pandas as pd
from glob import glob

# Paths to directories
dam_dir = 'data/DAM'
test_dir = 'data/test_image_headmind'

# Get list of image file paths for DAM and Test
dam_images = glob(os.path.join(dam_dir, '*.jpeg'))
test_images = glob(os.path.join(test_dir, '*.jpg'))  # assuming .jpeg, adjust if .jpg

# Create DataFrames
dam_df = pd.DataFrame({'image_path': dam_images})
test_df = pd.DataFrame({'image_path': test_images})

print("DAM DataFrame:")
print(dam_df.head())

print("\nTest DataFrame:")
print(test_df.head())


In [None]:
from rembg import remove
import rembg
# Single rembg session built once up front; preprocess_image passes it to every rembg.remove call.
rembg_session = rembg.new_session(model_name='u2net')

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import io

def preprocess_image(
    input: Image.Image,
    session=None,
    output_size: int = 518,
    alpha_threshold: float = 0.8,
    padding: float = 1.2,
    max_side: int = 1024,
) -> "Image.Image | None":
    """Isolate the main object of an image and return it centred on a white square.

    If ``input`` is RGBA and has at least one non-opaque pixel, its own alpha channel
    is used as the object mask. Otherwise the image is converted to RGB, downscaled so
    its longest side is at most ``max_side`` pixels, and its background is removed
    with rembg. The bounding box of pixels whose alpha exceeds ``alpha_threshold`` is
    expanded to a square ``padding`` times its longest side, cropped, resized to
    ``output_size`` x ``output_size``, and every pixel below the threshold is painted
    white.

    Args:
        input: Source image (any PIL mode). The name shadows the builtin but is kept
            for backward compatibility with keyword callers.
        session: rembg session to use; defaults to the module-level ``rembg_session``.
        output_size: Side length in pixels of the returned square image.
        alpha_threshold: Alpha level in [0, 1] above which a pixel counts as object.
        padding: Factor by which the object's bounding box is enlarged before cropping.
        max_side: Longest side allowed before background removal (larger images are
            downscaled with LANCZOS).

    Returns:
        An RGB image of size ``output_size`` x ``output_size``, or ``None`` when no
        pixel exceeds ``alpha_threshold`` (no object found).
    """
    # A fully opaque RGBA image carries no usable mask, so it still goes through rembg.
    has_alpha = input.mode == 'RGBA' and not np.all(np.array(input)[:, :, 3] == 255)
    if has_alpha:
        output = input
    else:
        rgb = input.convert('RGB')
        scale = min(1, max_side / max(rgb.size))
        if scale < 1:
            rgb = rgb.resize((int(rgb.width * scale), int(rgb.height * scale)), Image.Resampling.LANCZOS)
        output = rembg.remove(rgb, session=rembg_session if session is None else session)

    # Bounding box (left, top, right, bottom) of confidently-foreground pixels.
    alpha = np.array(output)[:, :, 3]
    ys, xs = np.nonzero(alpha > alpha_threshold * 255)
    if xs.size == 0:
        return None  # no object found
    left, top, right, bottom = xs.min(), ys.min(), xs.max(), ys.max()

    # Square crop centred on the object and enlarged by `padding`. Areas outside the
    # image are zero-filled by Pillow (transparent for RGBA) and become white below.
    cx, cy = (left + right) / 2, (top + bottom) / 2
    side = int(max(right - left, bottom - top) * padding)
    half = side // 2
    output = output.crop((cx - half, cy - half, cx + half, cy + half))  # type: ignore
    output = output.resize((output_size, output_size), Image.Resampling.LANCZOS)

    # Paint every pixel below the alpha threshold white, then drop the alpha channel.
    pixels = np.array(output).astype(np.float32) / 255
    pixels[pixels[:, :, 3] < alpha_threshold] = [1, 1, 1, 0]
    rgb_pixels = pixels[:, :, :3]
    return Image.fromarray((rgb_pixels * 255).astype(np.uint8))

# Visual sanity check of preprocess_image on a single DAM image.
SAMPLE_INDEX = 3  # row of dam_df to preview; any valid row works

sample_test_path = dam_df['image_path'].iloc[SAMPLE_INDEX]
print(f"Sample test path: {sample_test_path}")

# Read the file once; the in-memory copy lets the file handle close immediately.
with Image.open(sample_test_path) as sample_file:
    sample_image = sample_file.copy()
extracted_object = preprocess_image(sample_image)

# Display the original and extracted object
fig, (ax_orig, ax_ext) = plt.subplots(1, 2, figsize=(10, 5))
ax_orig.imshow(sample_image)
ax_orig.set_title("Original Image")
ax_orig.axis('off')

ax_ext.axis('off')
if extracted_object is None:
    # imshow(None) would raise, so show a placeholder title instead.
    print("No object found in the image at path:", sample_test_path)
    ax_ext.set_title("No object found")
else:
    ax_ext.imshow(extracted_object)
    ax_ext.set_title("Extracted Object")
plt.show()


In [None]:
import os
import pickle

# Must be set BEFORE TensorFlow is imported so that log messages emitted while
# TensorFlow loads are filtered too; setting it after the import is too late for those.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow warnings

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
# NOTE(review): imported unconditionally because model_configs references tf.keras,
# even when the "vit" backbone is selected.
import tensorflow as tf
from tensorflow.keras.preprocessing import image as keras_image
from sklearn.metrics.pairwise import cosine_similarity

# For ViT integration
import torch
from transformers import ViTImageProcessor, ViTModel

# Define configurations for different models
# Backbone registry. "type" picks the framework: "tf" = Keras application model,
# "pt" = PyTorch model loaded through transformers.
model_configs = dict(
    resnet50=dict(
        type="tf",
        constructor=tf.keras.applications.ResNet50,
        weights="imagenet",
        include_top=False,
        pooling="avg",
        target_size=(224, 224),
        preprocess_func=tf.keras.applications.resnet50.preprocess_input,
    ),
    vit=dict(
        type="pt",
        model_name="google/vit-base-patch16-224-in21k",
        target_size=(224, 224),  # ViT expects 224x224 images
    ),
)

# Backbone to use: "resnet50" or "vit".
selected_model_key = "vit"
config = model_configs[selected_model_key]

# Instantiate the backbone; the transformers path also needs its image processor.
if config["type"] == "pt":
    processor = ViTImageProcessor.from_pretrained(config["model_name"])
    base_model = ViTModel.from_pretrained(config["model_name"])
elif config["type"] == "tf":
    base_model = config["constructor"](
        **{arg: config[arg] for arg in ("weights", "include_top", "pooling")}
    )
else:
    raise ValueError("Unsupported model type")

# One cache file per backbone, so switching models never reuses another model's features.
embeddings_file = 'embeddings/dam_features_' + selected_model_key + '.pkl'

def preprocess_for_model(pil_image):
    """Turn a PIL image into the batch the selected backbone consumes.

    - "tf": a (1, H, W, C) float array resized to ``config["target_size"]`` and run
      through ``config["preprocess_func"]``.
    - "pt": the output of the transformers ``processor`` with ``return_tensors="pt"``.
    Any other ``config["type"]`` raises ValueError.
    """
    size = config["target_size"]  # only used on the "tf" path
    backend = config["type"]

    if backend == "pt":
        return processor(images=pil_image, return_tensors="pt")

    if backend != "tf":
        raise ValueError("Unsupported model type in preprocessing")

    # cv2.resize takes dsize as (width, height); the configured size is square.
    resized = cv2.resize(keras_image.img_to_array(pil_image), size)
    # Second img_to_array presumably restores a channel axis that cv2.resize drops
    # for single-channel input; it is a no-op for 3-channel arrays.
    batch = keras_image.img_to_array(resized)[np.newaxis, ...]
    return config["preprocess_func"](batch)

def extract_features(pil_image):
    """Embed one image with the selected backbone and return a flat 1-D numpy vector.

    - "tf": the Keras model's prediction, flattened.
    - "pt": the mean of the ViT's last hidden states over every token except the
      first (CLS) token.
    Any other ``config["type"]`` raises ValueError.
    """
    backend = config["type"]
    if backend not in ("tf", "pt"):
        raise ValueError("Unsupported model type in feature extraction")

    model_input = preprocess_for_model(pil_image)

    if backend == "tf":
        return base_model.predict(model_input, verbose=0).flatten()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        hidden = base_model(**model_input).last_hidden_state  # (1, seq_len, hidden_size)
    # Token 0 is CLS; average the remaining patch tokens into one embedding.
    patch_mean = hidden[:, 1:, :].mean(dim=1)
    return patch_mean.cpu().numpy().flatten()
    
# Build (or load from cache) the DAM embedding index: {image_path: feature_vector}.
# Uses dam_df, preprocess_image, extract_features and embeddings_file from earlier cells.
if os.path.exists(embeddings_file):
    # NOTE(review): pickle.load can execute arbitrary code — only load cache files this
    # notebook wrote. The cache is not invalidated when the DAM folder changes; delete
    # the file to force re-extraction.
    with open(embeddings_file, 'rb') as f:
        dam_features = pickle.load(f)
    print(f"Loaded precomputed DAM features from {embeddings_file}.")
else:
    dam_features = {}
    skipped_paths = []
    for img_path in tqdm(dam_df['image_path'], total=len(dam_df), desc="Extracting DAM features"):
        try:
            # `with` releases the file handle once preprocessing has read the pixels.
            with Image.open(img_path) as img:
                obj = preprocess_image(img)
        except OSError as exc:  # unreadable/corrupt file (PIL's UnidentifiedImageError is an OSError)
            tqdm.write(f"Skipping unreadable image {img_path}: {exc}")
            skipped_paths.append(img_path)
            continue
        if obj is None:
            # Skip if no object found
            skipped_paths.append(img_path)
            continue
        dam_features[img_path] = extract_features(obj)
    # `or '.'` keeps makedirs valid if the cache path ever has no directory part.
    os.makedirs(os.path.dirname(embeddings_file) or '.', exist_ok=True)
    with open(embeddings_file, 'wb') as f:
        pickle.dump(dam_features, f)
    print(f"Extracted features for DAM images and saved to {embeddings_file}.")
    if skipped_paths:
        print(f"{len(skipped_paths)} DAM image(s) skipped (unreadable or no object found).")

In [None]:
import math

# Set the desired number of top matches to display
top_n = 10
max_height = 7  # figure height (inches) per grid row

results = []

# Loop-invariant work hoisted out of the loop: stack all DAM embeddings once so each
# test image needs one vectorised cosine_similarity call instead of one per DAM image.
dam_paths = list(dam_features.keys())
dam_matrix = np.vstack(list(dam_features.values())) if dam_paths else None
if dam_matrix is None:
    print("Warning: no DAM features available; every match list will be empty.")

for test_path in tqdm(test_df['image_path'], total=len(test_df), desc="Processing test images"):
    # In-memory copy so the file handle is closed right away.
    with Image.open(test_path) as test_file:
        original_image = test_file.copy()

    # Preprocess and extract object from test image
    test_obj = preprocess_image(original_image)
    if test_obj is None:
        # extract_features cannot embed None; report and move on.
        tqdm.write(f"No object found in {test_path}; skipping.")
        continue

    # Extract features for the test object
    test_feat = extract_features(test_obj)

    # Similarity to every DAM image, then keep the `top_n` best. The stable sort keeps
    # DAM insertion order among ties, matching Python's sorted(..., reverse=True).
    if dam_matrix is None:
        top_matches = []
    else:
        sims = cosine_similarity(test_feat.reshape(1, -1), dam_matrix)[0]
        best = np.argsort(-sims, kind='stable')[:top_n]
        top_matches = [(dam_paths[i], sims[i]) for i in best]

    # Store results
    results.append({
        'test_image': test_path,
        'top_matches': top_matches
    })

    # Display match information in console
    tqdm.write(f"Processed {test_path} -> Top {top_n} matches:")
    for match_path, sim_score in top_matches:
        tqdm.write(f"\tMatch: {match_path}, Similarity: {sim_score:.4f}")

    # Grid layout: row 0 holds the original and extracted images; matches fill the
    # following rows, `ncols` per row (at least 2 columns, at most 5).
    ncols = min(max(2, top_n), 5)
    nrows = math.ceil(top_n / ncols) + 1

    fig = plt.figure(constrained_layout=True, figsize=(4 * ncols, max_height * nrows))
    gs = fig.add_gridspec(nrows, ncols)

    # Subplot for original image (row 0, col 0)
    ax_orig = fig.add_subplot(gs[0, 0])
    ax_orig.imshow(original_image)
    ax_orig.set_title("Original Image")
    ax_orig.axis('off')

    # Subplot for extracted object (row 0, col 1)
    ax_ext = fig.add_subplot(gs[0, 1])
    ax_ext.imshow(test_obj)
    ax_ext.set_title("Extracted Object")
    ax_ext.axis('off')

    # Leave remaining cells in the first row empty
    for col in range(2, ncols):
        fig.add_subplot(gs[0, col]).axis('off')

    # Matches fill rows 1.. in reading order
    for i, (match_path, sim_score) in enumerate(top_matches):
        ax = fig.add_subplot(gs[i // ncols + 1, i % ncols])
        with Image.open(match_path) as img_match:
            ax.imshow(img_match)
        ax.set_title(f"Match {i+1}\nSim: {sim_score:.4f}")
        ax.axis('off')

    plt.show()
    plt.close(fig)  # release the figure; one is created per test image