In [None]:
%pip install pandas scikit-learn Pillow numpy tqdm rembg matplotlib onnxruntime transformers torch torchvision torchaudio
%pip install kornia timm # RMBG-2.0

In [None]:
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
from tqdm import tqdm
import pandas as pd
from glob import glob
import PIL
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

In [None]:
# Locations of the reference catalogue (DAM) and of the query (test) images,
# plus the glob patterns accepted for test images.
dam_dir = 'data/DAM'
test_dir = 'data/test_image_headmind'
extensions = ['*.jpg', '*.jpeg', '*.png']

# DAM images: only '*.jpeg' files are collected, in sorted order.
dam_images = sorted(glob(os.path.join(dam_dir, '*.jpeg')))

# Test images: every pattern in `extensions`, merged and sorted.
test_images = sorted(
    found_path
    for ext in extensions
    for found_path in glob(os.path.join(test_dir, ext))
)

# One single-column DataFrame per image set.
dam_df = pd.DataFrame({'image_path': dam_images})
test_df = pd.DataFrame({'image_path': test_images})

# Preview the first rows of each frame.
for heading, frame in (("DAM DataFrame:", dam_df), ("\nTest DataFrame:", test_df)):
    print(heading)
    print(frame.head())


In [None]:
import importlib

import matplotlib.pyplot as plt
from PIL import Image, ImageOps as PIL_ImageOps

import utils.preprocessing

# Reload BEFORE binding `preprocess_image`: a name imported with
# `from module import name` keeps pointing at the old function object after
# importlib.reload(), so the import has to come after the reload for edits to
# utils/preprocessing.py to take effect in this kernel.
importlib.reload(utils.preprocessing)
from utils.preprocessing import preprocess_image

BASE_CACHE_DIR = '.cache'

# Device preference: CUDA, then Apple MPS, then CPU.
# torch.backends.mps.is_available() is used for the MPS check because it is
# the long-standing documented entry point.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Index of the test image used for the visual sanity check below; it is
# clamped so that smaller test sets do not raise an IndexError.
SAMPLE_TEST_INDEX = 60
if test_df.empty:
    raise FileNotFoundError(f"No test images found in {test_dir!r}; cannot preview a sample.")

sample_test_path = test_df['image_path'].iloc[min(SAMPLE_TEST_INDEX, len(test_df) - 1)]
print(f"Sample test path: {sample_test_path}")
extracted_object = preprocess_image(sample_test_path, background_removal="RMBG_2")

if extracted_object is None:
    print("No object found in the image at path:", sample_test_path)

# Display the original image next to the extracted object
fig, (ax_original, ax_extracted) = plt.subplots(1, 2, figsize=(10, 5))

# Open the file in a context manager so the handle is released once drawn.
with Image.open(sample_test_path) as sample_image:
    ax_original.imshow(PIL_ImageOps.exif_transpose(sample_image))
ax_original.set_title("Original Image")
ax_original.axis('off')

# Only draw the extraction when there is one; imshow(None) would raise.
if extracted_object is not None:
    ax_extracted.imshow(extracted_object)
ax_extracted.set_title("Extracted Object")
ax_extracted.axis('off')

plt.show()
plt.close(fig)

# for i in tqdm(range(len(dam_df))):
#     sample_test_path = dam_df['image_path'].iloc[i]
#     extracted_object = preprocess_image(sample_test_path)

#     if extracted_object is None:
#         print("No object found in the image at path:", sample_test_path)
#         continue


In [None]:
import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm
import os

# TF_CPP_MIN_LOG_LEVEL has to be set before TensorFlow is imported for the
# first time; setting it afterwards does not silence the import-time logs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow warnings

# NOTE(review): tensorflow is not part of the %pip install cell at the top of
# this notebook and is not used by the cells visible here - confirm whether
# these two imports are still needed.
import tensorflow as tf
from tensorflow.keras.preprocessing import image as keras_image
from sklearn.metrics.pairwise import cosine_similarity
import importlib
import json

import utils.models
import utils.models.nomic_embed_vision_model

# Reload each model sub-module first and the package last, then import the
# class names, so the names bound below come from the reloaded code.
importlib.reload(utils.models.base_model)
importlib.reload(utils.models.dinov2_model)
importlib.reload(utils.models.facebook_vitmsn_model)
importlib.reload(utils.models.google_vit_model)
importlib.reload(utils.models.microsoft_resnet_model)
importlib.reload(utils.models.openai_clip_model)
importlib.reload(utils.models.fashion_clip_model)
importlib.reload(utils.models.nomic_embed_vision_model)
importlib.reload(utils.models)
import utils.models
from utils.models import DinoV2Model, FacebookViTMSNModel, GoogleViTModel, MicrosoftResNetModel, OpenAIClipModel, FashionCLIPModel, NomicEmbedVisionModel


# --- Experiment configuration -------------------------------------------------
BACKGROUND_REMOVAL_METHOD = 'RMBG_2' # 'rembg', 'RMBG_2'
AUGMENT_WITH_3D_MODEL = False

SELECTED_MODEL = GoogleViTModel
EMBEDDING_AGGREGRATION_METHOD = "mean" # None, 'mean'
COSINE_SIMILARITY = True # False = Euclidean distance

# Instantiate the selected model; its class name is used in cache file names.
selected_model_key = SELECTED_MODEL.__name__
model = SELECTED_MODEL()

# Cache file for the 2D DAM embeddings, keyed by model and background-removal method
embeddings_file = os.path.join(BASE_CACHE_DIR, "embeddings", f'dam_features-{selected_model_key}-{BACKGROUND_REMOVAL_METHOD}.pkl')

# Drop feature dictionaries left over from a previous run of this cell so that
# they are re-read (or re-computed) for the configuration chosen above.
if "dam_features_2d" in locals():
    del dam_features_2d
if "dam_features_3d" in locals():
    del dam_features_3d
def aggregate_embedding(embedding):
    """Turn a model output into a row-vector array that can be stacked.

    Args:
        embedding: array-like with ``ndim``, ``squeeze`` and ``reshape``
            (a NumPy array in this notebook). Either a single vector, or a
            set of per-patch vectors with optional singleton axes.

    Returns:
        * single vector, with or without singleton axes (``(dim,)``,
          ``(1, dim)``, ...): reshaped to ``(1, dim)``;
        * two or more non-singleton axes and
          ``EMBEDDING_AGGREGRATION_METHOD == 'mean'``: averaged over the first
          remaining axis and reshaped to one row;
        * two or more non-singleton axes and any other method: the squeezed
          array, unchanged.
    """
    if embedding.ndim >= 2:
        embedding = embedding.squeeze()
        if embedding.ndim < 2:
            # Squeezing left a single vector (e.g. a pooled output of shape
            # (1, dim)). Averaging it over axis 0 would collapse the whole
            # vector into one scalar, so just return it as a row.
            return embedding.reshape(1, -1)
        if EMBEDDING_AGGREGRATION_METHOD == 'mean':
            embedding = np.mean(embedding, axis=0).reshape(1, -1)
        return embedding
    return embedding.reshape(1, -1)

def _store_batch_features(images, paths, store):
    """Extract features for one batch and record each one under its image path."""
    feats_batch = model.extract_features(images)
    for path, feat in zip(paths, feats_batch):
        store[path] = feat


# Load the 2D DAM features from the cache when it exists, otherwise compute
# them in batches and write the cache.
if os.path.exists(embeddings_file):
    if "dam_features_2d" not in locals():
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # cache files that were produced by this notebook.
        with open(embeddings_file, 'rb') as f:
            dam_features_2d = pickle.load(f)
    print(f"Loaded precomputed DAM features from {embeddings_file}.")
else:
    dam_features_2d = {}
    batch_size_2d = 32  # Adjust batch size according to your memory and performance requirements
    images_batch_2d = []
    paths_batch_2d = []

    for img_path in tqdm(dam_df['image_path'], total=len(dam_df), desc="Extracting DAM features (2D)"):
        img = preprocess_image(img_path, BACKGROUND_REMOVAL_METHOD)
        if img is None:
            # Skip if no object found
            continue

        images_batch_2d.append(img)
        paths_batch_2d.append(img_path)

        # Flush the batch as soon as it is full
        if len(images_batch_2d) == batch_size_2d:
            _store_batch_features(images_batch_2d, paths_batch_2d, dam_features_2d)
            images_batch_2d = []
            paths_batch_2d = []

    # Flush whatever is left in the last, partially filled batch
    if images_batch_2d:
        _store_batch_features(images_batch_2d, paths_batch_2d, dam_features_2d)

    # Only persist a non-empty result: an empty cache file would be picked up
    # by every later run and hide the fact that nothing was extracted.
    if dam_features_2d:
        os.makedirs(os.path.dirname(embeddings_file), exist_ok=True)
        with open(embeddings_file, 'wb') as f:
            pickle.dump(dam_features_2d, f)
        print(f"Extracted features for DAM images (2D) and saved to {embeddings_file}.")

if not dam_features_2d:
    raise RuntimeError(
        f"No 2D DAM features are available (cache file: {embeddings_file}). "
        "Check that the DAM image directory contains images, that preprocessing "
        "finds objects, and delete the cache file if it is empty."
    )

# One coefficient per feature dimension, all ones (i.e. no re-weighting yet).
feature_selection_coefficients = np.ones(next(iter(dam_features_2d.values())).shape[-1], dtype=np.float32)
print(f"Feature selection coefficients shape: {feature_selection_coefficients.shape}")

# Build the DAM feature matrix, optionally blending each 3D render with the
# 2D feature of the DAM image it belongs to.

if AUGMENT_WITH_3D_MODEL:
    embeddings_file_3d = os.path.join(BASE_CACHE_DIR, "embeddings", f'dam_features-{selected_model_key}-rembg-3d.pkl')

    if os.path.exists(embeddings_file_3d):
        if "dam_features_3d" not in locals():
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load cache files that were produced by this notebook.
            with open(embeddings_file_3d, 'rb') as f:
                dam_features_3d = pickle.load(f)
        print(f"Loaded precomputed DAM features (3D) from {embeddings_file_3d}.")
    else:
        dam_features_3d = {}
        batch_size = 32  # Adjust batch size according to your memory and performance requirements
        images_batch = []
        paths_batch = []

        for idx, row in tqdm(dam_df.iterrows(), total=len(dam_df), desc="Extracting DAM features (3D)"):
            img_path = row['image_path']
            dam_id = os.path.basename(img_path).split('.')[0]

            # Renders are looked up as <cache>/TRELLIS/<id>/<id>-1.png ... <id>-8.png
            for i in range(1, 9):
                img_path_3d = f"{BASE_CACHE_DIR}/TRELLIS/{dam_id}/{dam_id}-{i}.png"
                if not os.path.exists(img_path_3d):
                    continue
                img = preprocess_image(img_path_3d, "rembg")
                if img is None:
                    # Skip if no object found
                    continue

                images_batch.append(img)
                paths_batch.append(img_path_3d)

                # Flush the batch as soon as it is full
                if len(images_batch) == batch_size:
                    feats_batch = model.extract_features(images_batch)
                    for path, feat in zip(paths_batch, feats_batch):
                        dam_features_3d[path] = feat
                    images_batch = []
                    paths_batch = []

        # Flush whatever is left in the last, partially filled batch
        if images_batch:
            feats_batch = model.extract_features(images_batch)
            for path, feat in zip(paths_batch, feats_batch):
                dam_features_3d[path] = feat

        # Only persist a non-empty result so an empty cache cannot mask a
        # failed extraction on later runs.
        if dam_features_3d:
            os.makedirs(os.path.dirname(embeddings_file_3d), exist_ok=True)
            with open(embeddings_file_3d, 'wb') as f:
                pickle.dump(dam_features_3d, f)
            print(f"Extracted features for DAM images (3D) and saved to {embeddings_file_3d}.")

    if not dam_features_3d:
        raise RuntimeError(
            f"No 3D DAM features are available (cache file: {embeddings_file_3d}). "
            f"Check that renders exist under {BASE_CACHE_DIR}/TRELLIS."
        )

    # Equal weighting (0.5) of the 3D and 2D feature for every dimension.
    merging_coefficients = np.ones(next(iter(dam_features_3d.values())).shape[-1], dtype=np.float32) * 0.5

    # Map 3D features to their corresponding 2D features
    map_3d_feature_to_dam_feature = {}
    combined_features_list = []
    dam_paths = []

    for path_3d, feat_3d in dam_features_3d.items():
        # The DAM ID is the name of the folder holding the render
        dam_id = os.path.basename(os.path.dirname(path_3d))
        # Rebuild the key from `dam_dir`, the same prefix the glob used when
        # the 2D paths were collected, so the lookup matches on every OS.
        corresponding_path = os.path.join(dam_dir, f"{dam_id}.jpeg")

        if corresponding_path in dam_features_2d:
            map_3d_feature_to_dam_feature[path_3d] = corresponding_path
            feat_2d = dam_features_2d[corresponding_path]

            # Weighted blend: 3D * c + 2D * (1 - c)
            combined_feat = (aggregate_embedding(feat_3d) * merging_coefficients) + (aggregate_embedding(feat_2d) * (1 - merging_coefficients))

            # Apply feature selection coefficients (broadcast over the row)
            combined_feat = combined_feat * feature_selection_coefficients

            combined_features_list.append(combined_feat)
            dam_paths.append(path_3d)
        else:
            print(f"No corresponding DAM feature found for 3D image {path_3d}.")
else:
    # 2D only: one aggregated, re-weighted row per DAM image
    combined_features_list = []
    dam_paths = []

    print(next(iter(dam_features_2d.values())).shape)

    for path, feat in tqdm(dam_features_2d.items(), desc="Building DAM feature matrix (2D)"):
        aggregated_feat = aggregate_embedding(feat) * feature_selection_coefficients

        combined_features_list.append(aggregated_feat)
        dam_paths.append(path)

if not combined_features_list:
    raise RuntimeError("No DAM features are available to build the feature matrix.")

# Stack the per-image rows into one matrix and keep the paths alongside
dam_features_matrix = np.vstack(combined_features_list)  # Shape: (num_dam, feature_dim)
dam_feature_paths = np.array(dam_paths)

feature_source = "2D+3D" if AUGMENT_WITH_3D_MODEL else "2D"
print(f"Built DAM feature matrix from {feature_source} features with shape {dam_features_matrix.shape}.")

# Normalize features if using cosine similarity
if COSINE_SIMILARITY:
    # L2 norm of every row
    norms = np.linalg.norm(dam_features_matrix, axis=1, keepdims=True)
    # Avoid division by zero by setting zero norms to one
    norms[norms == 0] = 1
    dam_features_matrix = dam_features_matrix / norms
    print("Normalized DAM feature matrix for cosine similarity.")

# Store as float32
dam_features_matrix = dam_features_matrix.astype(np.float32)

# dam_features = dam_features_2d
# dam_features = dam_features_2d | dam_features_3d
# dam_features = dam_features_3d

In [None]:
import pandas as pd

def _parse_references(raw_reference):
    """Split a '/'-separated reference cell into a list of clean labels.

    Blank parts and the '?' placeholder are dropped. A missing cell (NaN)
    yields an empty list instead of the literal label 'nan'.
    """
    if pd.isna(raw_reference):
        return []
    parts = (part.strip() for part in str(raw_reference).split('/'))
    return [part for part in parts if part and part != '?']


# Load the labels CSV into a DataFrame
labels_df = pd.read_csv('labels/handmade_test_labels.csv')

# Map each test image filename to its list of reference labels
labels_dict = {}
for raw_image_name, raw_reference in zip(labels_df['image'], labels_df['reference']):
    if pd.isna(raw_image_name):
        # A row without an image name cannot be matched to any test file
        continue
    labels_dict[str(raw_image_name).strip()] = _parse_references(raw_reference)

# For debugging, print a sample of the dictionary
print(dict(list(labels_dict.items())[:5]))

In [None]:
import math
import time

# --- Evaluation settings -------------------------------------------------------
top_n = 10       # number of distinct matches kept (and displayed) per query
max_height = 5   # figure height per grid row

# --- Running counters ----------------------------------------------------------
total_queries = 0
correct_top1 = 0
correct_top3 = 0
correct_top5 = 0

# test image name -> 0-based rank of its first correct match (only when found)
correct_match_indices = {}

results = []


def _match_code(match_path):
    """Return the reference code encoded in a match file name.

    The extension is dropped and anything after the first '-' is ignored, so
    'CODE-3.png' and 'CODE.jpeg' both give 'CODE'.
    """
    stem = os.path.splitext(os.path.basename(match_path))[0]
    return stem.split("-")[0]


def _show_query_figure(test_path, test_obj, matches):
    """Plot the query image, its extracted object and the retained matches."""
    ncols = min(max(2, top_n), 5)  # Ensure at least 2 columns
    nrows = math.ceil(top_n / ncols) + 1  # Extra first row for original + extracted images

    fig = plt.figure(constrained_layout=True, figsize=(4 * ncols, max_height * nrows))
    gs = fig.add_gridspec(nrows, ncols)

    ax_orig = fig.add_subplot(gs[0, 0])
    with Image.open(test_path) as query_image:
        ax_orig.imshow(PIL_ImageOps.exif_transpose(query_image))
    ax_orig.set_title("Original Image")
    ax_orig.axis('off')

    ax_ext = fig.add_subplot(gs[0, 1])
    ax_ext.imshow(test_obj)
    ax_ext.set_title("Extracted Object")
    ax_ext.axis('off')

    # Leave remaining cells in the first row empty
    for col in range(2, ncols):
        ax_empty = fig.add_subplot(gs[0, col])
        ax_empty.axis('off')

    for i, (match_path, sim_score) in enumerate(matches[:top_n]):
        ax = fig.add_subplot(gs[i // ncols + 1, i % ncols])
        with Image.open(match_path) as match_image:
            ax.imshow(PIL_ImageOps.exif_transpose(match_image))
        ax.set_title(f"Match {i+1}\nSim: {sim_score:.4f}")
        ax.axis('off')

    plt.show()
    # Release the figure explicitly so figures do not pile up across queries
    plt.close(fig)


# Start Processing Test Images
for test_path in tqdm(test_df['image_path'], total=len(test_df), desc="Processing test images"):
    # Every test image counts as a query, including those whose preprocessing fails
    total_queries += 1

    # Preprocess Test Image
    t = time.time()
    test_obj = preprocess_image(test_path, BACKGROUND_REMOVAL_METHOD)
    if test_obj is None:
        print(f"Preprocessing failed for {test_path}. Skipping.")
        continue
    preprocessing_time = time.time() - t
    print(f"Preprocessing time: {preprocessing_time:.4f} seconds")

    # Extract Features
    t = time.time()
    test_feat = model.extract_features(test_obj)
    feature_extraction_time = time.time() - t
    print(f"Feature extraction time: {feature_extraction_time:.4f} seconds")

    # Aggregate Features
    t = time.time()
    test_feat = aggregate_embedding(test_feat)
    # Test images only have a 2D view, so with or without 3D augmentation of
    # the DAM side the query is just re-weighted by the selection coefficients.
    test_feat_combined = test_feat * feature_selection_coefficients
    aggregation_time = time.time() - t
    print(f"Feature aggregation time: {aggregation_time:.4f} seconds")

    # Normalize Test Feature if using cosine similarity
    if COSINE_SIMILARITY:
        norm = np.linalg.norm(test_feat_combined)
        if norm == 0:
            norm = 1
        test_feat_combined = test_feat_combined / norm

    test_feat_combined = test_feat_combined.astype(np.float32).squeeze()

    # Compute Similarities
    t = time.time()
    if COSINE_SIMILARITY:
        # Dot product equals cosine similarity because both sides are normalized
        similarities = np.dot(dam_features_matrix, test_feat_combined)
    else:
        # Negative Euclidean distance, so that "larger is better" in both modes
        differences = dam_features_matrix - test_feat_combined
        similarities = -np.linalg.norm(differences, axis=1)
    similarity_time = time.time() - t
    print(f"Similarity computation time: {similarity_time:.4f} seconds")

    # Rank ALL candidates. Several DAM entries can share one reference code
    # (e.g. multiple renders of one object), so de-duplicating only the raw
    # top-N could leave fewer than top_n distinct codes to score.
    t = time.time()
    ranking = np.argsort(-similarities)
    raw_top_indices = ranking[:top_n]
    sorted_matches = list(zip(dam_feature_paths[raw_top_indices], similarities[raw_top_indices]))
    sorting_time = time.time() - t
    print(f"Sorting time: {sorting_time:.4f} seconds")

    # Store Results (raw top-N, duplicates included)
    results.append({
        'test_image': test_path,
        'top_matches': sorted_matches[:top_n]
    })

    # Retrieve True Reference Labels
    test_image_name = os.path.basename(test_path)
    true_references = labels_dict.get(test_image_name, [])

    # Keep the best match per code and note the rank of every correct code
    found_indices = []
    top_n_matches = []
    code_duplicates = set()
    for match_index in ranking:
        match_path = dam_feature_paths[match_index]
        predicted_code = _match_code(match_path)

        if predicted_code in code_duplicates:
            continue

        if predicted_code in true_references:
            found_indices.append(len(top_n_matches))

        code_duplicates.add(predicted_code)
        top_n_matches.append((match_path, similarities[match_index]))
        if len(top_n_matches) >= top_n:
            break

    # Update Accuracy Counters
    if found_indices:
        correct_index = min(found_indices)
        correct_match_indices[test_image_name] = correct_index
        if correct_index < 1:
            correct_top1 += 1
        if correct_index < 3:
            correct_top3 += 1
        if correct_index < 5:
            correct_top5 += 1

    # Visualization
    _show_query_figure(test_path, test_obj, top_n_matches)

    # Running accuracy after this query
    accuracy_top1 = correct_top1 / total_queries
    accuracy_top3 = correct_top3 / total_queries
    accuracy_top5 = correct_top5 / total_queries

    print(f"Total queries processed: {total_queries}")
    print(f"Top-1 Accuracy: {accuracy_top1:.2%} ({correct_top1} correct)")
    print(f"Top-3 Accuracy: {accuracy_top3:.2%} ({correct_top3} correct)")
    print(f"Top-5 Accuracy: {accuracy_top5:.2%} ({correct_top5} correct)")

    # Rank at which the correct match was found for this image, if any
    if test_image_name in correct_match_indices:
        print(f"Correct match found at top-{correct_match_indices[test_image_name]+1}")

    # Display Match Information in Console
    tqdm.write(f"Processed {test_path} -> Top {top_n} matches:")
    for match in top_n_matches[:top_n]:
        tqdm.write(f"\tMatch: {match[0]}, Similarity: {match[1]:.4f}")

# Final accuracies are computed here, after the loop, so that queries skipped
# by a failed preprocessing (which bypass the in-loop update) are still part
# of the denominator and the names are defined even if every query was skipped.
if total_queries > 0:
    accuracy_top1 = correct_top1 / total_queries
    accuracy_top3 = correct_top3 / total_queries
    accuracy_top5 = correct_top5 / total_queries
else:
    print("No queries were processed.")
    accuracy_top1 = accuracy_top3 = accuracy_top5 = 0.0

# After all test images are processed, save the benchmark
filename = f"{type(model).__name__}-{BACKGROUND_REMOVAL_METHOD}-{EMBEDDING_AGGREGRATION_METHOD}-{'euclidean' if not COSINE_SIMILARITY else 'cosine'}"
if AUGMENT_WITH_3D_MODEL:
    filename += "-3d"
filename += ".json"

benchmark = {
    "top_1_accuracy": accuracy_top1,
    "top_3_accuracy": accuracy_top3,
    "top_5_accuracy": accuracy_top5
}

os.makedirs("benchmarks", exist_ok=True)
with open(os.path.join("benchmarks", filename), 'w') as f:
    json.dump(benchmark, f)

print(f"Benchmark results saved to benchmarks/{filename}")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Set random seeds for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# This cell relies on `device`, `labels_dict`, `preprocess_image`, `model` and
# `aggregate_embedding` from earlier cells, and on the 3D feature structures
# that are only built when AUGMENT_WITH_3D_MODEL is True.
_missing_inputs = [
    name
    for name in ("dam_features_2d", "dam_features_3d", "map_3d_feature_to_dam_feature")
    if name not in globals()
]
if _missing_inputs:
    raise RuntimeError(
        f"Missing {', '.join(_missing_inputs)}: run the feature-extraction cell "
        "with AUGMENT_WITH_3D_MODEL = True before training the coefficients."
    )

# Pick 20 random test images (seeded shuffle) as the training samples.
shuffled_indices = list(range(len(test_df)))
random.shuffle(shuffled_indices)
train_indices = shuffled_indices[:20]
train_samples = test_df.iloc[train_indices]
train_paths = list(train_samples['image_path'])

# Size the coefficient vectors from the stored features rather than
# hard-coding one model's embedding width.
feature_dim = next(iter(dam_features_2d.values())).shape[-1]
TRIPLET_MARGIN = 0.2

# Trainable coefficient vectors, created directly on the target device.
merging_coefficients = torch.nn.Parameter(torch.ones(feature_dim, device=device) * 0.5)
feature_selection_coefficients = torch.nn.Parameter(torch.ones(feature_dim, device=device) * 0.5)

# One optimizer for both coefficient vectors.
optimizer = optim.Adam([merging_coefficients, feature_selection_coefficients], lr=0.01)

# Row-wise cosine similarity.
cos = nn.CosineSimilarity(dim=1).to(device)

num_epochs = 100
losses = []

all_dam_paths = list(dam_features_3d.keys())

# Keep the cache across re-runs of this cell.
# NOTE(review): the cache is keyed by path only, so it is not invalidated when
# the model or background-removal method changes - clear it manually then.
if 'cached_test_features' not in locals():
    cached_test_features = {}

# Cache all test features for faster processing
for test_path in test_df['image_path']:
    if test_path in cached_test_features:
        continue
    # Use the same background removal as the evaluation cell so training and
    # evaluation see features from the same preprocessing.
    test_obj = preprocess_image(test_path, BACKGROUND_REMOVAL_METHOD)
    if test_obj is None:
        print(f"Preprocessing failed for {test_path}. Skipping.")
        continue
    test_feat = model.extract_features(test_obj)
    cached_test_features[test_path] = aggregate_embedding(test_feat)

for epoch in range(num_epochs):
    total_loss = 0.0
    optimizer.zero_grad()

    # Gradients are accumulated over all samples; one optimizer step per epoch.
    for test_path in train_paths:
        test_feat = cached_test_features.get(test_path)
        if test_feat is None:
            # Preprocessing failed for this image when the cache was built
            continue

        # Retrieve references for the current test image
        test_image_name = os.path.basename(test_path)
        references = labels_dict.get(test_image_name, [])
        if not references:
            continue

        # Choose one correct reference and one of its eight renders at random
        correct_reference = random.choice(references)
        correct_dam_3d_path = f"{BASE_CACHE_DIR}/TRELLIS/{correct_reference}/{correct_reference}-{random.randint(1, 8)}.png"

        if correct_dam_3d_path not in dam_features_3d:
            continue
        correct_dam_2d_path = map_3d_feature_to_dam_feature.get(correct_dam_3d_path)
        if correct_dam_2d_path is None:
            # The render has no matching 2D DAM feature
            continue

        # Negative sample: another training image that has cached features
        # and shares no reference with the current one (up to 10 draws).
        negative_path = None
        for _ in range(10):
            candidate = random.choice(train_paths)
            if candidate == test_path or candidate not in cached_test_features:
                continue
            candidate_references = labels_dict.get(os.path.basename(candidate), [])
            if set(candidate_references) & set(references):
                continue
            negative_path = candidate
            break

        if negative_path is None:
            # Without a negative there is no triplet; skip instead of reusing
            # a stale (or undefined) negative distance.
            print("No valid negative sample found.")
            continue

        # Anchor: blend of the 3D and 2D feature of the correct DAM entry
        dam_feat_3d = aggregate_embedding(dam_features_3d[correct_dam_3d_path])
        dam_feat_2d = aggregate_embedding(dam_features_2d[correct_dam_2d_path])
        dam_feat_3d_tensor = torch.tensor(dam_feat_3d, dtype=torch.float32, device=device)
        dam_feat_2d_tensor = torch.tensor(dam_feat_2d, dtype=torch.float32, device=device)

        merge_coeff_expanded = merging_coefficients.view(1, -1)
        combined_correct = dam_feat_3d_tensor * merge_coeff_expanded + dam_feat_2d_tensor * (1 - merge_coeff_expanded)

        # Positive (the query itself) and negative (another query) features
        test_feat_tensor = torch.tensor(test_feat, dtype=torch.float32, device=device)
        neg_feat_tensor = torch.tensor(cached_test_features[negative_path], dtype=torch.float32, device=device)

        # Apply feature selection to all three embeddings
        feat_sel_expanded = feature_selection_coefficients.view(1, -1)
        weighted_test = test_feat_tensor * feat_sel_expanded
        weighted_correct = combined_correct * feat_sel_expanded
        weighted_negative = neg_feat_tensor * feat_sel_expanded

        # Cosine distances from the anchor to the positive and to the negative
        positive_distance = 1 - cos(weighted_test, weighted_correct).mean()
        negative_distance = 1 - cos(weighted_correct, weighted_negative).mean()

        # Triplet margin loss
        loss = torch.clamp(positive_distance - negative_distance + TRIPLET_MARGIN, min=0.0)

        total_loss += loss.item()
        loss.backward()

    optimizer.step()

    # Keep both coefficient vectors inside [0, 1]
    with torch.no_grad():
        merging_coefficients.clamp_(0.0, 1.0)
        feature_selection_coefficients.clamp_(0.0, 1.0)

    # Average over the number of training samples (skipped ones included)
    avg_loss = total_loss / max(len(train_samples), 1)
    losses.append(avg_loss)

    clear_output(wait=True)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    loss_fig = plt.figure(figsize=(8, 4))
    plt.plot(range(1, epoch+2), losses, marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Average Loss')
    plt.title('Training Loss Over Epochs')
    plt.grid(True)
    plt.show()
    plt.close(loss_fig)


# Convert the trained parameters to NumPy arrays (on CPU) for later cells.
merging_coefficients = merging_coefficients.detach().cpu().numpy()
feature_selection_coefficients = feature_selection_coefficients.detach().cpu().numpy()

print("Optimized merging coefficients:", merging_coefficients)
print("Optimized feature selection coefficients:", feature_selection_coefficients)

In [None]:
def _coefficients_to_numpy(coefficients):
    """Return coefficients as a NumPy array.

    Accepts a torch tensor / Parameter (detached and moved to CPU) or
    anything NumPy can already treat as an array. The training cell leaves
    NumPy arrays behind, on which `.cpu()` does not exist.
    """
    if isinstance(coefficients, torch.Tensor):
        return coefficients.detach().cpu().numpy()
    return np.asarray(coefficients)


# Save both coefficient vectors with pickle, one file each, named after the
# selected model and background-removal method.
merging_coefficients_file = os.path.join(BASE_CACHE_DIR, "coefficients", f"merging_coefficients-{selected_model_key}-{BACKGROUND_REMOVAL_METHOD}.pkl")
os.makedirs(os.path.dirname(merging_coefficients_file), exist_ok=True)
with open(merging_coefficients_file, 'wb') as f:
    pickle.dump(_coefficients_to_numpy(merging_coefficients), f)

feature_selection_coefficients_file = os.path.join(BASE_CACHE_DIR, "coefficients", f"feature_selection_coefficients-{selected_model_key}-{BACKGROUND_REMOVAL_METHOD}.pkl")
os.makedirs(os.path.dirname(feature_selection_coefficients_file), exist_ok=True)
with open(feature_selection_coefficients_file, 'wb') as f:
    pickle.dump(_coefficients_to_numpy(feature_selection_coefficients), f)
    