# Jewelry Image Similarity with Siamese Networks

This notebook demonstrates how to build an image similarity model for finding visually similar jewelry items.

## Overview
- **Task**: Image similarity / retrieval
- **Approach**: Siamese network with triplet loss
- **Use Case**: Find similar jewelry items based on visual appearance
- **Output**: Vector embeddings that can be compared using distance metrics
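
As a minimal sketch of what "compared using distance metrics" can look like, the snippet below compares toy 3-D vectors with Euclidean distance and cosine similarity. The vectors and the names `ring_a`, `ring_b`, and `necklace` are made-up illustrations, not model outputs. For L2-normalized vectors, squared Euclidean distance equals `2 - 2 * cosine`, so both measures rank neighbors the same way.

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def euclidean(a, b):
    """Euclidean (L2) distance between two vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def cosine(a, b):
    """Cosine similarity; for unit vectors this is just the dot product."""
    return sum(x * y for x, y in zip(a, b))

# Toy embeddings (hypothetical values for illustration only)
ring_a = l2_normalize([0.9, 0.1, 0.2])
ring_b = l2_normalize([0.8, 0.2, 0.1])
necklace = l2_normalize([0.1, 0.9, 0.3])

d_similar = euclidean(ring_a, ring_b)
d_different = euclidean(ring_a, necklace)
print(f"ring_a vs ring_b:   {d_similar:.3f}")
print(f"ring_a vs necklace: {d_different:.3f}")

# For unit vectors: squared Euclidean distance = 2 - 2 * cosine similarity
identity_gap = abs(d_similar ** 2 - (2 - 2 * cosine(ring_a, ring_b)))
```

The smaller distance between the two ring vectors is what a "similar item" lookup relies on.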

## 1. Setup and Imports

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import random
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.manifold import TSNE
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print("✓ All imports successful!")

## 2. Data Loading and Exploration

In [None]:
# Set up data paths
DATA_DIR = Path("/project/data/raw_jewelry")
MODEL_DIR = Path("/project/models")
MODEL_DIR.mkdir(parents=True, exist_ok=True)  # Also create missing parent directories

# Explore the dataset structure
categories = sorted([d.name for d in DATA_DIR.iterdir() if d.is_dir()])
print(f"Number of categories: {len(categories)}")
print(f"Categories: {categories}\n")

# Collect all images with their categories
all_images = []
for category in categories:
    category_path = DATA_DIR / category
    image_files = list(category_path.glob("*.jpg")) + list(category_path.glob("*.png"))
    for img_path in image_files:
        all_images.append({'path': img_path, 'category': category})
    print(f"{category}: {len(image_files)} images")

print(f"\nTotal images: {len(all_images)}")

# Create lookup by category for triplet sampling
images_by_category = defaultdict(list)
for img_info in all_images:
    images_by_category[img_info['category']].append(img_info['path'])

## 3. Configuration and Transforms

In [None]:
# Configuration
IMG_SIZE = 224
BATCH_SIZE = 16  # Smaller batch for triplet learning
EPOCHS = 30
EMBEDDING_DIM = 128  # Size of embedding vector
LEARNING_RATE = 0.0001
MARGIN = 1.0  # Margin for triplet loss
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training transforms (resize + light augmentation); validation uses val_transform below
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print(f"Device: {DEVICE}")
print(f"Image size: {IMG_SIZE}x{IMG_SIZE}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Embedding dimension: {EMBEDDING_DIM}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Triplet margin: {MARGIN}")

## 4. Triplet Dataset

For similarity learning, we use triplets: (anchor, positive, negative)
- **Anchor**: Reference image
- **Positive**: Different image of same category (similar)
- **Negative**: Image from different category (dissimilar)
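
The sampling rule above can be sketched on a toy catalog before wiring it into a `Dataset`. This is an illustrative sketch only: `sample_triplet`, `toy_images`, and the file names are hypothetical and use no real images.

```python
import random

def sample_triplet(images_by_category, anchor_category, anchor, rng=random):
    """Pick a positive from the anchor's category and a negative from another one."""
    positives = [p for p in images_by_category[anchor_category] if p != anchor]
    # Fall back to the anchor itself when its category has a single image
    positive = rng.choice(positives) if positives else anchor
    other_categories = [c for c in images_by_category if c != anchor_category]
    negative_category = rng.choice(other_categories)
    negative = rng.choice(images_by_category[negative_category])
    return anchor, positive, negative

# Hypothetical file names, for illustration only
toy_images = {
    "rings": ["ring_01.jpg", "ring_02.jpg", "ring_03.jpg"],
    "necklaces": ["necklace_01.jpg", "necklace_02.jpg"],
    "earrings": ["earring_01.jpg"],
}

rng = random.Random(0)  # Seeded generator so the sketch is repeatable
anchor, positive, negative = sample_triplet(toy_images, "rings", "ring_01.jpg", rng)
print(anchor, positive, negative)
```

Note that this rule needs at least two categories, and a category with a single image falls back to using the anchor as its own positive.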

In [None]:
class TripletDataset(Dataset):
    def __init__(self, images_by_category, categories, transform=None):
        self.images_by_category = images_by_category
        self.categories = categories
        self.transform = transform
        
        # Flatten to get all image paths
        self.all_images = []
        for cat in categories:
            self.all_images.extend([(path, cat) for path in images_by_category[cat]])
    
    def __len__(self):
        return len(self.all_images)
    
    def __getitem__(self, idx):
        # Get anchor
        anchor_path, anchor_category = self.all_images[idx]
        
        # Get positive (same category, different image)
        positive_candidates = [p for p in self.images_by_category[anchor_category] if p != anchor_path]
        if len(positive_candidates) == 0:
            positive_path = anchor_path  # Fallback if only one image in category
        else:
            positive_path = random.choice(positive_candidates)
        
        # Get negative (different category)
        negative_category = random.choice([c for c in self.categories if c != anchor_category])
        negative_path = random.choice(self.images_by_category[negative_category])
        
        # Load and transform images
        anchor = Image.open(anchor_path).convert('RGB')
        positive = Image.open(positive_path).convert('RGB')
        negative = Image.open(negative_path).convert('RGB')
        
        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)
        
        return anchor, positive, negative

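# Split images into train/validation sets, stratified by category, so that
# validation triplets come from images the model has not trained on.
# Assumptions: a 20% hold-out, and every category has enough images to appear in both splits.
train_items, val_items = train_test_split(
    all_images, test_size=0.2, random_state=42,
    stratify=[item['category'] for item in all_images]
)

def group_by_category(items):
    """Map category -> list of image paths for the given items."""
    grouped = defaultdict(list)
    for item in items:
        grouped[item['category']].append(item['path'])
    return grouped

train_categories = categories  # All categories are used in both splits
train_dataset = TripletDataset(group_by_category(train_items), train_categories, transform=transform)
val_dataset = TripletDataset(group_by_category(val_items), train_categories, transform=val_transform)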

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 5. Siamese Network Architecture

A single embedding network with shared weights maps the anchor, positive, and negative images of each triplet into the same embedding space.

In [None]:
class EmbeddingNetwork(nn.Module):
    def __init__(self, embedding_dim=128):
        super(EmbeddingNetwork, self).__init__()
        
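        # Use ImageNet-pretrained MobileNetV2 as backbone
        # (`weights=` replaces the deprecated `pretrained=True`; requires torchvision >= 0.13)
        mobilenet = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)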
        
        # Remove the classifier
        self.features = mobilenet.features
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        
        # Embedding head
        self.embedding = nn.Sequential(
            nn.Linear(1280, 512),  # MobileNetV2 outputs 1280 features
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, embedding_dim)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.embedding(x)
        # L2 normalize embeddings
        x = F.normalize(x, p=2, dim=1)
        return x

# Create model
model = EmbeddingNetwork(embedding_dim=EMBEDDING_DIM)
model = model.to(DEVICE)

print("Model created successfully!")
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

## 6. Triplet Loss

The triplet loss trains the embeddings so that:
- The anchor-positive distance is small
- The anchor-negative distance is large
- The anchor-negative distance exceeds the anchor-positive distance by at least the margin
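
For a single triplet, the loss is `max(d(a, p) - d(a, n) + margin, 0)`. The sketch below evaluates it by hand on toy 2-D unit vectors (made-up values, not model outputs) to show when a triplet stops contributing. Because the embeddings in this notebook are L2-normalized, all distances lie in `[0, 2]`, so a margin of 1.0 is reachable.

```python
import math

def euclidean(a, b):
    """Euclidean (L2) distance between two vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def triplet_loss(anchor, positive, negative, margin=1.0):
    """max(d(a, p) - d(a, n) + margin, 0) for a single triplet."""
    return max(euclidean(anchor, positive) - euclidean(anchor, negative) + margin, 0.0)

# Toy 2-D unit vectors (hypothetical values for illustration only)
anchor = [1.0, 0.0]
positive = [0.8, 0.6]         # d(a, p) = sqrt(0.04 + 0.36) ~ 0.632
far_negative = [-1.0, 0.0]    # d(a, n) = 2.0
near_negative = [0.6, 0.8]    # d(a, n) = sqrt(0.16 + 0.64) ~ 0.894

easy = triplet_loss(anchor, positive, far_negative)
hard = triplet_loss(anchor, positive, near_negative)
print(f"easy triplet loss: {easy:.3f}")  # 0.000 (margin already satisfied)
print(f"hard triplet loss: {hard:.3f}")  # 0.738 (negative is too close)
```

Triplets whose negative is already farther than the positive by more than the margin give zero loss and zero gradient, which is why randomly sampled triplets become less informative as training progresses.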

In [None]:
class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin
    
    def forward(self, anchor, positive, negative):
        # Euclidean distance
        pos_dist = F.pairwise_distance(anchor, positive, p=2)
        neg_dist = F.pairwise_distance(anchor, negative, p=2)
        
        # Triplet loss: max(d(a,p) - d(a,n) + margin, 0)
        losses = F.relu(pos_dist - neg_dist + self.margin)
        
        return losses.mean(), pos_dist.mean(), neg_dist.mean()

criterion = TripletLoss(margin=MARGIN)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Note: the `verbose` argument is deprecated in recent PyTorch releases, so it is omitted here
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

print(f"Triplet loss with margin: {MARGIN}")
print(f"Optimizer: Adam (lr={LEARNING_RATE})")

## 7. Training Loop

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    running_pos_dist = 0.0
    running_neg_dist = 0.0
    
    pbar = tqdm(loader, desc='Training')
    # Count batches explicitly: tqdm refreshes pbar.n lazily, so it can lag behind
    for step, (anchor, positive, negative) in enumerate(pbar, start=1):
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)
        
        optimizer.zero_grad()
        
        # Get embeddings
        anchor_emb = model(anchor)
        positive_emb = model(positive)
        negative_emb = model(negative)
        
        # Calculate loss
        loss, pos_dist, neg_dist = criterion(anchor_emb, positive_emb, negative_emb)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        running_pos_dist += pos_dist.item()
        running_neg_dist += neg_dist.item()
        
        # Show running averages over the batches seen so far
        pbar.set_postfix({
            'loss': running_loss / step,
            'pos_dist': running_pos_dist / step,
            'neg_dist': running_neg_dist / step
        })
    
    epoch_loss = running_loss / len(loader)
    epoch_pos_dist = running_pos_dist / len(loader)
    epoch_neg_dist = running_neg_dist / len(loader)
    return epoch_loss, epoch_pos_dist, epoch_neg_dist

def validate_epoch(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    running_pos_dist = 0.0
    running_neg_dist = 0.0
    
    with torch.no_grad():
        pbar = tqdm(loader, desc='Validation')
        # Count batches explicitly: tqdm refreshes pbar.n lazily, so it can lag behind
        for step, (anchor, positive, negative) in enumerate(pbar, start=1):
            anchor = anchor.to(device)
            positive = positive.to(device)
            negative = negative.to(device)
            
            anchor_emb = model(anchor)
            positive_emb = model(positive)
            negative_emb = model(negative)
            
            loss, pos_dist, neg_dist = criterion(anchor_emb, positive_emb, negative_emb)
            
            running_loss += loss.item()
            running_pos_dist += pos_dist.item()
            running_neg_dist += neg_dist.item()
            
            # Show running averages over the batches seen so far
            pbar.set_postfix({
                'loss': running_loss / step,
                'pos_dist': running_pos_dist / step,
                'neg_dist': running_neg_dist / step
            })
    
    epoch_loss = running_loss / len(loader)
    epoch_pos_dist = running_pos_dist / len(loader)
    epoch_neg_dist = running_neg_dist / len(loader)
    return epoch_loss, epoch_pos_dist, epoch_neg_dist

# Train the model
print("Starting training...\n")
history = {
    'train_loss': [], 'train_pos_dist': [], 'train_neg_dist': [],
    'val_loss': [], 'val_pos_dist': [], 'val_neg_dist': []
}

best_val_loss = float('inf')

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    print("-" * 60)
    
    train_loss, train_pos, train_neg = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    val_loss, val_pos, val_neg = validate_epoch(model, val_loader, criterion, DEVICE)
    
    history['train_loss'].append(train_loss)
    history['train_pos_dist'].append(train_pos)
    history['train_neg_dist'].append(train_neg)
    history['val_loss'].append(val_loss)
    history['val_pos_dist'].append(val_pos)
    history['val_neg_dist'].append(val_neg)
    
    print(f"Train - Loss: {train_loss:.4f}, Pos Dist: {train_pos:.4f}, Neg Dist: {train_neg:.4f}")
    print(f"Val   - Loss: {val_loss:.4f}, Pos Dist: {val_pos:.4f}, Neg Dist: {val_neg:.4f}")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_DIR / 'best_similarity_model.pth')
        print(f"✓ Saved best model with validation loss: {val_loss:.4f}")
    
    scheduler.step(val_loss)
    print()

print("\n✓ Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")

## 8. Visualize Training Results

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

epochs_range = range(1, len(history['train_loss']) + 1)

# Loss
axes[0].plot(epochs_range, history['train_loss'], label='Training Loss', linewidth=2, marker='o')
axes[0].plot(epochs_range, history['val_loss'], label='Validation Loss', linewidth=2, marker='s')
axes[0].set_title('Triplet Loss', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Distances
axes[1].plot(epochs_range, history['train_pos_dist'], label='Train Positive Distance', linewidth=2, marker='o')
axes[1].plot(epochs_range, history['train_neg_dist'], label='Train Negative Distance', linewidth=2, marker='s')
axes[1].plot(epochs_range, history['val_pos_dist'], label='Val Positive Distance', linewidth=2, marker='^', linestyle='--')
axes[1].plot(epochs_range, history['val_neg_dist'], label='Val Negative Distance', linewidth=2, marker='v', linestyle='--')
axes[1].set_title('Embedding Distances', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Distance')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal Training Loss: {history['train_loss'][-1]:.4f}")
print(f"Final Validation Loss: {history['val_loss'][-1]:.4f}")
print(f"\nFinal Validation Positive Distance: {history['val_pos_dist'][-1]:.4f}")
print(f"Final Validation Negative Distance: {history['val_neg_dist'][-1]:.4f}")
print(f"Distance Gap (negative - positive): {history['val_neg_dist'][-1] - history['val_pos_dist'][-1]:.4f}")

## 9. Generate Embeddings for All Images

In [None]:
def get_embedding(model, image_path, transform, device):
    """Get embedding for a single image"""
    model.eval()
    img = Image.open(image_path).convert('RGB')
    img_tensor = transform(img).unsqueeze(0).to(device)
    
    with torch.no_grad():
        embedding = model(img_tensor)
    
    return embedding.cpu().numpy()[0]

# Generate embeddings for all images
print("Generating embeddings for all images...")
embeddings_data = []

for img_info in tqdm(all_images, desc='Computing embeddings'):
    embedding = get_embedding(model, img_info['path'], val_transform, DEVICE)
    embeddings_data.append({
        'path': str(img_info['path']),
        'category': img_info['category'],
        'embedding': embedding
    })

print(f"\n✓ Generated {len(embeddings_data)} embeddings")
print(f"Embedding dimension: {embeddings_data[0]['embedding'].shape}")

## 10. Visualize Embeddings with t-SNE

In [None]:
# Prepare data for t-SNE
embeddings_array = np.array([item['embedding'] for item in embeddings_data])
categories_list = [item['category'] for item in embeddings_data]

# Reduce to 2D using t-SNE
print("Running t-SNE dimensionality reduction...")
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
embeddings_2d = tsne.fit_transform(embeddings_array)

# Plot
plt.figure(figsize=(14, 10))
colors = plt.cm.tab10(range(len(categories)))
category_to_color = {cat: colors[i] for i, cat in enumerate(categories)}

for category in categories:
    mask = np.array([cat == category for cat in categories_list])
    plt.scatter(
        embeddings_2d[mask, 0],
        embeddings_2d[mask, 1],
        c=[category_to_color[category]],
        label=category,
        alpha=0.7,
        s=100,
        edgecolors='black',
        linewidth=0.5
    )

plt.title('t-SNE Visualization of Jewelry Embeddings', fontsize=16, fontweight='bold', pad=20)
plt.xlabel('t-SNE Dimension 1', fontsize=12)
plt.ylabel('t-SNE Dimension 2', fontsize=12)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\n✓ Embeddings visualized!")
print("Similar items should cluster together in the plot.")

## 11. Similarity Search: Find Similar Items

In [None]:
def find_similar_items(query_embedding, embeddings_data, top_k=5, exclude_query=True):
    """Find k most similar items to query embedding"""
    query_emb = query_embedding.reshape(1, -1)
    
    # Compute distances
    distances = []
    for i, item in enumerate(embeddings_data):
        dist = np.linalg.norm(query_emb - item['embedding'])
        distances.append((i, dist, item))
    
    # Sort by distance
    distances.sort(key=lambda x: x[1])
    
    # Exclude query itself if needed (assumes the query is in embeddings_data and ranks first at distance 0)
    if exclude_query:
        results = distances[1:top_k+1]
    else:
        results = distances[:top_k]
    
    return results

def visualize_similar_items(query_idx, embeddings_data, top_k=5):
    """Visualize query image and its most similar items"""
    query_item = embeddings_data[query_idx]
    query_emb = query_item['embedding']
    
    similar_items = find_similar_items(query_emb, embeddings_data, top_k=top_k)
    
    # Plot
    fig, axes = plt.subplots(1, top_k + 1, figsize=(20, 4))
    
    # Query image
    query_img = Image.open(query_item['path'])
    axes[0].imshow(query_img)
    axes[0].set_title(f"QUERY\n{query_item['category']}", 
                     fontsize=12, fontweight='bold', color='blue')
    axes[0].axis('off')
    axes[0].set_xlabel('Query Image', fontsize=10, fontweight='bold')
    
    # Similar images
    for i, (idx, distance, item) in enumerate(similar_items, 1):
        img = Image.open(item['path'])
        axes[i].imshow(img)
        
        # Color code: green if same category, red if different
        color = 'green' if item['category'] == query_item['category'] else 'red'
        axes[i].set_title(f"{item['category']}\nDist: {distance:.3f}",
                         fontsize=10, color=color)
        axes[i].axis('off')
        axes[i].set_xlabel(f'Rank #{i}', fontsize=9)
    
    plt.suptitle('Image Similarity Search Results (Green=Same Category, Red=Different)', 
                fontsize=14, fontweight='bold', y=1.05)
    plt.tight_layout()
    plt.show()

print("Similarity search functions ready!")

In [None]:
# Demo: Find similar items for random queries
print("Running similarity search demos...\n")

for i in range(3):
    random_idx = np.random.randint(0, len(embeddings_data))
    print(f"Demo {i+1}: Query from category '{embeddings_data[random_idx]['category']}'")
    visualize_similar_items(random_idx, embeddings_data, top_k=5)
    print()

## 12. Evaluation: Retrieval Metrics

In [None]:
def evaluate_retrieval(embeddings_data, k_values=(1, 3, 5, 10)):
    """Evaluate retrieval performance using Precision@K"""
    results = {k: [] for k in k_values}
    max_k = max(k_values)
    
    for query_idx, query_item in enumerate(tqdm(embeddings_data, desc='Evaluating')):
        query_category = query_item['category']
        query_emb = query_item['embedding']
        
        # Rank neighbors once for the largest k, then reuse the ranking for smaller k
        similar = find_similar_items(query_emb, embeddings_data, top_k=max_k, exclude_query=True)
        
        for k in k_values:
            # Calculate precision: how many of top-k are same category?
            correct = sum(1 for _, _, item in similar[:k] if item['category'] == query_category)
            precision = correct / k
            results[k].append(precision)
    
    # Average precision
    avg_precision = {k: np.mean(results[k]) for k in k_values}
    
    return avg_precision

# Evaluate
print("Evaluating retrieval performance...\n")
precision_at_k = evaluate_retrieval(embeddings_data, k_values=[1, 3, 5, 10])

print("\nRetrieval Performance (Precision@K):")
print("=" * 40)
for k, prec in precision_at_k.items():
    print(f"Precision@{k}: {prec:.4f} ({prec*100:.2f}%)")

# Plot
plt.figure(figsize=(10, 6))
k_vals = list(precision_at_k.keys())
prec_vals = list(precision_at_k.values())
plt.bar(k_vals, prec_vals, color='skyblue', edgecolor='navy', alpha=0.7)
plt.xlabel('K (Number of Retrieved Items)', fontsize=12)
plt.ylabel('Precision@K', fontsize=12)
plt.title('Retrieval Performance', fontsize=14, fontweight='bold')
plt.ylim(0, 1.0)
plt.grid(axis='y', alpha=0.3)
# Bars are drawn at x = k, so place each label at the same x position
for k, prec in precision_at_k.items():
    plt.text(k, prec + 0.02, f'{prec:.3f}', ha='center', fontsize=10, fontweight='bold')
plt.tight_layout()
plt.show()

## 13. Save Model and Embeddings

In [None]:
import json
import pickle

# Save model
final_model_path = MODEL_DIR / 'jewelry_similarity_final.pth'
torch.save(model.state_dict(), final_model_path)
print(f"✓ Model saved to: {final_model_path}")

# Save embeddings
embeddings_path = MODEL_DIR / 'jewelry_embeddings.pkl'
with open(embeddings_path, 'wb') as f:
    pickle.dump(embeddings_data, f)
print(f"✓ Embeddings saved to: {embeddings_path}")

# Save metadata
metadata = {
    'embedding_dim': EMBEDDING_DIM,
    'img_size': IMG_SIZE,
    'categories': categories,
    'num_images': len(embeddings_data),
    'precision_at_k': {str(k): float(v) for k, v in precision_at_k.items()}
}
metadata_path = MODEL_DIR / 'similarity_metadata.json'
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2)
print(f"✓ Metadata saved to: {metadata_path}")

# Model summary
model_size_mb = os.path.getsize(final_model_path) / (1024 * 1024)
embeddings_size_mb = os.path.getsize(embeddings_path) / (1024 * 1024)

print(f"\n{'='*70}")
print("MODEL SUMMARY")
print(f"{'='*70}")
print(f"Task: Image Similarity (Siamese Network)")
print(f"Embedding dimension: {EMBEDDING_DIM}")
print(f"Number of images: {len(embeddings_data)}")
print(f"Categories: {len(categories)}")
print(f"Model size: {model_size_mb:.2f} MB")
print(f"Embeddings size: {embeddings_size_mb:.2f} MB")
print(f"Precision@5: {precision_at_k[5]:.4f}")
print(f"{'='*70}")

## Summary

This notebook demonstrated image similarity learning for jewelry:

1. **Data Preparation**: Created triplet dataset (anchor, positive, negative)
2. **Model Architecture**: Siamese network with MobileNetV2 backbone
3. **Training**: Triplet loss to learn discriminative embeddings
4. **Embeddings**: Generated 128-D vectors for all images
5. **Visualization**: t-SNE plot showing clustering by category
6. **Similarity Search**: Find visually similar items using distance metrics
7. **Evaluation**: Precision@K metrics for retrieval quality
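
To make the Precision@K metric from step 7 concrete, here is a small worked example on a made-up ranking. The helper `toy_precision_at_k` and the list `ranked` are hypothetical and independent of the notebook's `evaluate_retrieval` function.

```python
def toy_precision_at_k(ranked_categories, query_category, k):
    """Fraction of the top-k retrieved items that share the query's category."""
    top_k = ranked_categories[:k]
    return sum(1 for c in top_k if c == query_category) / k

# Hypothetical ranking for a 'rings' query, nearest neighbor first
ranked = ["rings", "rings", "bracelets", "rings", "necklaces"]

for k in (1, 3, 5):
    print(f"Precision@{k}: {toy_precision_at_k(ranked, 'rings', k):.2f}")
# Precision@1: 1.00
# Precision@3: 0.67
# Precision@5: 0.60
```

Here "relevant" means "same category", so the metric measures category agreement rather than finer-grained visual similarity.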

**Key Differences from Classification:**
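- No fixed number of classes: new items can be added to the index by computing their embeddings, without retraining
- Encodes similarity as distances between embeddings rather than discrete labels (here, training is supervised only by category membership)
- Useful for recommendation systems
- May generalize to jewelry types not seen during training, although this notebook does not test that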
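
The point about adding new items without retraining can be sketched with a toy in-memory index. The item ids and 2-D vectors below are hypothetical; in the notebook, a new item's vector would come from `get_embedding` with the trained model.

```python
import math

def nearest(query, index, top_k=1):
    """Return the top_k (distance, item_id) pairs closest to the query embedding."""
    scored = [(math.dist(query, emb), item_id) for item_id, emb in index]
    return sorted(scored)[:top_k]

# Toy index of (item_id, embedding) pairs; values are hypothetical
index = [
    ("ring_01", [0.9, 0.1]),
    ("necklace_01", [0.1, 0.9]),
]

query = [0.6, 0.5]
before = nearest(query, index)[0][1]

# A new catalog item only needs its embedding computed and appended;
# the embedding model itself is untouched
index.append(("bracelet_01", [0.6, 0.6]))
after = nearest(query, index)[0][1]
print(before, "->", after)
```

A linear scan like this is fine for small catalogs; larger ones would typically use an approximate nearest-neighbor index instead.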

**Use Cases:**
- "Find similar items" in e-commerce
- Visual search engines
- Duplicate detection
- Product recommendations
- Clustering and organization
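
As one illustration of the duplicate-detection use case, the sketch below flags pairs of items whose embeddings are closer than a threshold. The embeddings, item ids, and the threshold of 0.05 are assumed values for illustration; a real threshold would need tuning on labeled duplicate pairs.

```python
import math
from itertools import combinations

def find_near_duplicates(items, threshold):
    """Return pairs of item ids whose embeddings are closer than `threshold`."""
    pairs = []
    for (id_a, emb_a), (id_b, emb_b) in combinations(items, 2):
        if math.dist(emb_a, emb_b) < threshold:
            pairs.append((id_a, id_b))
    return pairs

# Hypothetical embeddings; ring_01 and ring_01_copy are nearly identical
items = [
    ("ring_01", [0.90, 0.10]),
    ("ring_01_copy", [0.91, 0.09]),
    ("necklace_01", [0.10, 0.90]),
]

# The threshold is an assumed value, not one derived from this notebook
duplicates = find_near_duplicates(items, threshold=0.05)
print(duplicates)
```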