In [None]:
# Install required packages
%pip install torch torchvision scikit-learn tqdm

In [None]:
import os
import pickle
import json
import random
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
from glob import glob

In [None]:
# Device configuration: prefer CUDA, then Apple MPS, then fall back to CPU.
# `torch.backends.mps.is_available()` is the long-standing availability check;
# `torch.mps.is_available()` is absent from older PyTorch releases and would
# raise AttributeError on any machine without CUDA. The getattr guard keeps
# builds that ship no MPS backend at all on the CPU path.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Load precomputed DAM and 3D embeddings
BASE_CACHE_DIR = '.cache'
selected_model_key = "GoogleViTModel"  # adjust if needed
BACKGROUND_REMOVAL_METHOD = 'RMBG_2'
EMBEDDING_AGGREGRATION_METHOD = "mean"  # spelling kept as-is; not read anywhere in this cell

# Cache files: the main DAM embeddings, plus an optional file of 3D-augmentation embeddings
embeddings_file = os.path.join(BASE_CACHE_DIR, "embeddings", f'dam_features-{selected_model_key}-{BACKGROUND_REMOVAL_METHOD}.pkl')
embeddings_file_3d = os.path.join(BASE_CACHE_DIR, "embeddings", f'dam_features-{selected_model_key}-rembg-3d.pkl')

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load cache files that were produced locally / come from a trusted source.
with open(embeddings_file, 'rb') as f:
    dam_features = pickle.load(f)

# The 3D file is optional; when present its entries are merged into dam_features
# (presumably both are dicts keyed by image path - an entry in the 3D file
# replaces a same-key entry from the main file).
if os.path.exists(embeddings_file_3d):
    with open(embeddings_file_3d, 'rb') as f:
        dam_features_3d = pickle.load(f)
    # Merge original and 3D features
    dam_features.update(dam_features_3d)

# Function to aggregate embeddings if needed
def aggregate_embedding(embedding):
    """Reduce a raw embedding to its aggregated form.

    Args:
        embedding: array-like exposing ``.shape``, ``.squeeze()`` and
            ``.flatten()`` (e.g. a NumPy array).

    Returns:
        * More than 2 dimensions: size-1 axes are squeezed away and the result
          is averaged over its first remaining axis, e.g. (1, tokens, dim) ->
          (dim,). If squeezing leaves fewer than 2 dimensions (e.g. the input
          was (1, 1, dim)) there is nothing to average over, so the squeezed
          values are returned as a 1-D vector instead of being collapsed to a
          scalar.
        * 2 dimensions or fewer: the input flattened to 1-D.
    """
    if len(embedding.shape) > 2:
        squeezed = embedding.squeeze()
        if len(squeezed.shape) < 2:
            # Degenerate input: averaging over axis 0 here would average over
            # the feature axis itself and destroy the vector.
            return squeezed.flatten()
        # Mean over the first remaining axis (the token axis for (1, tokens, dim) input)
        return np.mean(squeezed, axis=0)
    return embedding.flatten()

# Replace every stored embedding with its aggregated form, in place.
# The items are snapshotted into a list first so the dict is never iterated
# while it is being written to.
for key, raw_embedding in list(dam_features.items()):
    dam_features[key] = aggregate_embedding(raw_embedding)

In [None]:
# ----------------------------------------------------------------------------
# Network Definition
# ----------------------------------------------------------------------------

class EmbeddingNetwork(nn.Module):
    """Small fully connected projection network: Linear -> ReLU -> Linear.

    Args:
        input_dim: size of each input vector.
        hidden_dim: width of the hidden layer and of the output vector.
    """

    def __init__(self, input_dim, hidden_dim=512):
        super().__init__()
        # The stack is registered as `fc`, with the two Linear layers at
        # positions 0 and 2 of the Sequential.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` passed through the ``fc`` stack."""
        projected = self.fc(x)
        return projected

In [None]:
# ----------------------------------------------------------------------------
# Triplet Loss and Dataset Preparation
# ----------------------------------------------------------------------------

# Loss: PyTorch's stock triplet margin loss, with p=2 distances and a margin of 1.0
triplet_loss_fn = nn.TripletMarginLoss(p=2, margin=1.0)

class TripletDataset(Dataset):
    """Dataset yielding (anchor, positive, negative) float32 tensors."""

    def __init__(self, dam_features, anchor_positive_pairs, negative_samples):
        """
        dam_features: dict mapping image paths to embeddings.
        anchor_positive_pairs: list of tuples (anchor_path, positive_path).
        negative_samples: dict mapping anchor_path to a list of negative paths.
        """
        self.dam_features = dam_features
        self.anchor_positive_pairs = anchor_positive_pairs
        self.negative_samples = negative_samples

    def __len__(self):
        """One item per (anchor, positive) pair."""
        return len(self.anchor_positive_pairs)

    def __getitem__(self, idx):
        """Build the triplet for pair ``idx``; the negative is drawn at random on every call."""
        anchor_path, positive_path = self.anchor_positive_pairs[idx]

        # Draw the negative first (uses the global `random` module state)
        candidates = self.negative_samples[anchor_path]
        negative_path = None
        if candidates:
            negative_path = random.choice(candidates)

        features = self.dam_features
        anchor_emb = features[anchor_path]
        positive_emb = features[positive_path]
        if negative_path:
            negative_emb = features[negative_path]
        else:
            # Fallback when no negative is available: reuse the positive
            negative_emb = positive_emb

        return tuple(
            torch.tensor(emb, dtype=torch.float32)
            for emb in (anchor_emb, positive_emb, negative_emb)
        )

In [None]:
# ----------------------------------------------------------------------------
# Create Triplet Data for Training
# ----------------------------------------------------------------------------

def _product_code(path):
    """Return the product code of an image path.

    The code is the basename up to its first '.', then up to its first '-',
    so an original image ("CODE.jpeg") and its variants ("CODE-1.jpeg") map
    to the same code. Using one rule everywhere matters: deriving the anchor's
    code with split("-") alone keeps the extension ("CODE.jpeg"), which never
    equals a variant's "CODE" and lets the anchor's own variants be treated
    as negatives.
    """
    return os.path.basename(path).split(".")[0].split("-")[0]


# Group 3D augmentation images by their original DAM product id.
# A "-" in the basename is what marks a path as a 3D variant.
product_to_3d = {}
for path in dam_features.keys():
    if "-" not in os.path.basename(path):
        continue
    product_to_3d.setdefault(_product_code(path), []).append(path)

# Create anchor-positive pairs: pair each original DAM image with every one of
# its 3D variants. Variants are never anchors themselves.
anchor_positive_pairs = []
for path in list(dam_features.keys()):
    if "-" in os.path.basename(path):
        # Skip 3D variants
        continue
    for var_path in product_to_3d.get(_product_code(path), []):
        anchor_positive_pairs.append((path, var_path))

# Create negative samples for each anchor: every image (original or variant)
# that belongs to a different product.
negative_samples = {}
all_paths = list(dam_features.keys())
# Product codes are computed once per path instead of once per (anchor, path) comparison.
path_products = {p: _product_code(p) for p in all_paths}
for anchor_path, _ in anchor_positive_pairs:
    if anchor_path in negative_samples:
        # An anchor appears once per positive; its negative list only needs building once.
        continue
    anchor_prod = path_products[anchor_path]
    negative_samples[anchor_path] = [
        neg_path for neg_path in all_paths if path_products[neg_path] != anchor_prod
    ]

In [None]:
# ----------------------------------------------------------------------------
# Initialize Network, Loss, and DataLoader for Triplet Training
# ----------------------------------------------------------------------------

# The network's input width is read off the first stored embedding
sample_embedding = next(iter(dam_features.values()))
input_dim = sample_embedding.shape[0]

# Projection network on the selected device, optimised with Adam
embedding_net = EmbeddingNetwork(input_dim=input_dim)
embedding_net = embedding_net.to(device)
optimizer = torch.optim.Adam(params=embedding_net.parameters(), lr=1e-4)

# Triplet dataset and a shuffling loader over it (batches of 32)
triplet_dataset = TripletDataset(
    dam_features=dam_features,
    anchor_positive_pairs=anchor_positive_pairs,
    negative_samples=negative_samples,
)
triplet_dataloader = DataLoader(dataset=triplet_dataset, batch_size=32, shuffle=True)

In [None]:
import matplotlib.pyplot as plt

# ----------------------------------------------------------------------------
# Triplet Training Loop with Loss Tracking
# ----------------------------------------------------------------------------

num_epochs = 100
epoch_losses = []  # mean batch loss of each completed epoch

# Placeholder shown in the first epoch's progress bar, before any loss exists
avg_loss = np.inf

for epoch in range(num_epochs):
    embedding_net.train()
    running_loss = 0.0
    # The bar description carries the previous epoch's average loss
    progress = tqdm(triplet_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    for batch in progress:
        # Each batch holds exactly three tensors: anchor, positive, negative
        anchor, positive, negative = (part.to(device) for part in batch)

        optimizer.zero_grad()

        # All three members go through the same network
        anchor_out = embedding_net(anchor)
        positive_out = embedding_net(positive)
        negative_out = embedding_net(negative)

        loss = triplet_loss_fn(anchor_out, positive_out, negative_out)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Average over the number of batches in the loader
    avg_loss = running_loss / len(triplet_dataloader)
    epoch_losses.append(avg_loss)

# ----------------------------------------------------------------------------
# Plotting the Loss Curve
# ----------------------------------------------------------------------------

epoch_numbers = range(1, num_epochs + 1)

plt.figure(figsize=(8, 5))
plt.plot(epoch_numbers, epoch_losses, label='Training Loss', marker='o')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.title('Triplet Training Loss Over Epochs')
plt.xticks(epoch_numbers)
plt.grid(True)
plt.legend()
plt.show()

In [None]:
# ----------------------------------------------------------------------------
# Evaluation on Label Dataset with Precomputation and Visualization
# ----------------------------------------------------------------------------

import matplotlib.pyplot as plt
from PIL import Image
from glob import glob
import pandas as pd
import torch
from tqdm import tqdm
import utils.preprocessing
from utils.preprocessing import preprocess_image
from utils.models.google_vit_model import GoogleViTModel

# The trained projection network is evaluated under the name `model`
model = embedding_net

# Paths to directories
dam_dir = 'data/DAM'
test_dir = 'data/test_image_headmind'
extensions = ['*.jpg', '*.jpeg', '*.png']

# DAM images: only *.jpeg files are collected, in sorted order
dam_images = sorted(glob(os.path.join(dam_dir, '*.jpeg')))

# Test images: every file matching any of the extension patterns, in sorted order
test_images = sorted(
    match
    for ext in extensions
    for match in glob(os.path.join(test_dir, ext))
)

# Create DataFrames
dam_df = pd.DataFrame({'image_path': dam_images})
test_df = pd.DataFrame({'image_path': test_images})

# Load labels CSV into {image name: [reference, ...]}.
# A reference cell may list several codes separated by '/'; blank entries and
# '?' placeholders are dropped.
labels_df = pd.read_csv('labels/handmade_test_labels.csv')
labels_dict = {}
for _, row in labels_df.iterrows():
    stripped_refs = (ref.strip() for ref in str(row['reference']).split('/'))
    labels_dict[row['image'].strip()] = [ref for ref in stripped_refs if ref and ref != '?']

# Rank precomputed DAM outputs against a single test output
def find_best_match_with_precomputed(test_output, dam_outputs, top_n=1):
    """Return the ``top_n`` nearest DAM entries to ``test_output``.

    Args:
        test_output: tensor for the query image.
        dam_outputs: dict mapping DAM path -> tensor.
        top_n: number of entries to return.

    Returns:
        List of ``(dam_path, distance)`` tuples sorted by ascending distance,
        where distance is ``torch.norm`` of the difference (smaller = closer).
    """
    ranked = sorted(
        (
            (dam_path, torch.norm(test_output - dam_out).item())
            for dam_path, dam_out in dam_outputs.items()
        ),
        key=lambda pair: pair[1],
    )
    return ranked[:top_n]

# Initialize VIT model for extracting test embeddings
vit_model = GoogleViTModel()
vit_model_device = device  # kept for compatibility; not read anywhere in this cell

# Precompute network outputs for all DAM images (originals and any merged 3D entries)
model.eval()
dam_outputs = {}
with torch.no_grad():
    for path, emb in dam_features.items():
        # Add a batch dimension, run the network, then drop the batch dimension again.
        # The module is called as `model(...)` rather than `model.forward(...)` so that
        # any registered hooks run, as PyTorch recommends.
        emb_tensor = torch.tensor(emb, dtype=torch.float32).unsqueeze(0).to(device)
        dam_out = model(emb_tensor)
        dam_outputs[path] = dam_out.squeeze(0)

# Precompute network outputs for each test image
test_outputs = {}  # maps test image path -> network output
for idx, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Precomputing test outputs"):
    test_path = row['image_path']
    # Use the same background-removal setting that selected the DAM embedding cache,
    # so query and gallery images are preprocessed consistently.
    test_obj = preprocess_image(test_path, background_removal=BACKGROUND_REMOVAL_METHOD)
    if test_obj is None:
        # Images for which preprocessing returns None are left out of the evaluation
        continue
    with torch.no_grad():
        # Extract VIT embedding and aggregate
        test_feat = vit_model.extract_features(test_obj)
    test_emb = aggregate_embedding(test_feat)
    test_tensor = torch.tensor(test_emb, dtype=torch.float32).unsqueeze(0).to(device)
    with torch.no_grad():
        test_out = model(test_tensor)
    test_outputs[test_path] = test_out.squeeze(0)

correct_top1 = 0
total = 0

# Now evaluate using the precomputed outputs
for test_path, test_out in tqdm(test_outputs.items(), desc="Evaluating"):
    # Nearest DAM entry in the learned space
    top_match = find_best_match_with_precomputed(test_out, dam_outputs, top_n=1)
    predicted_path, distance = top_match[0]
    # Product code: basename up to the first '.', then up to the first '-'
    predicted_code = os.path.basename(predicted_path).split('.')[0].split('-')[0]

    # True references for the current test image (empty list when it has no label row)
    test_image_name = os.path.basename(test_path)
    true_references = labels_dict.get(test_image_name, [])

    # First DAM path whose basename starts with one of the true references,
    # trying the references in their listed order; None when nothing matches.
    true_match_path = next(
        (
            path
            for ref in true_references
            for path in dam_features.keys()
            if os.path.basename(path).startswith(ref)
        ),
        None,
    )

    # Every evaluated image counts towards the total, labelled or not
    if predicted_code in true_references:
        correct_top1 += 1
    total += 1

    # Plot test image, predicted and true match images side by side
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Images are opened in `with` blocks so their file handles are released as
    # soon as imshow has taken the pixel data, instead of leaking three handles
    # per test image.
    with Image.open(test_path) as test_img:
        axes[0].imshow(test_img)
    axes[0].set_title("Test Image")
    axes[0].axis('off')

    # Display predicted image
    with Image.open(predicted_path) as pred_img:
        axes[1].imshow(pred_img)
    axes[1].set_title(f"Predicted Match\n{os.path.basename(predicted_path)}")
    axes[1].axis('off')

    # Display true match image if found
    if true_match_path:
        with Image.open(true_match_path) as true_img:
            axes[2].imshow(true_img)
        axes[2].set_title(f"True Match\n{os.path.basename(true_match_path)}")
    else:
        axes[2].text(0.5, 0.5, "True match not found",
                     horizontalalignment='center', verticalalignment='center', fontsize=12)
        axes[2].set_title("True Match")
    axes[2].axis('off')

    plt.show()
    # Release the figure explicitly so figures do not accumulate across the loop
    plt.close(fig)

accuracy_top1 = correct_top1 / total if total > 0 else 0
print(f"Top-1 Accuracy on test set: {accuracy_top1:.2%} ({correct_top1}/{total})")

# ----------------------------------------------------------------------------
# Save Model and Benchmark Results
# ----------------------------------------------------------------------------

# Save trained Siamese model.
# The target directory is created first: torch.save does not create missing
# parent directories, so a fresh checkout without .cache/models would fail here
# after the whole training/evaluation run has completed.
model_path = ".cache/models/siamese_model.pth"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

# Save benchmark results (top-1 accuracy only) as JSON under benchmarks/
benchmark = {
    "top_1_accuracy": accuracy_top1
}
os.makedirs("benchmarks", exist_ok=True)
with open(os.path.join("benchmarks", "siamese_benchmark.json"), 'w') as f:
    json.dump(benchmark, f)