# Embedding Database Creation & k-NN Classification Testing

This notebook:
1. Loads BEiT v2 model from checkpoint
2. Extracts embeddings from train dataset
3. Computes class centroids
4. Filters out conflicting (likely mislabeled) samples, then selects the top-100 samples per class closest to the class centroid for the k-NN database
5. Tests: ArcFace only vs ArcFace + k-NN on validation set
6. Analyzes improvements and failure cases

## 1. Setup & Configuration

In [None]:
import sys
sys.path.append('/home/fishial/Fishial/FishialGithubRepo/fish-identification')

import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from collections import defaultdict, Counter
from sklearn.metrics import (
    accuracy_score, 
    classification_report, 
    confusion_matrix,
    top_k_accuracy_score
)
from sklearn.metrics import pairwise_distances
import faiss
import fiftyone as fo
from pathlib import Path
import json

# Import your training modules
from module.classification_package.src.lightning_trainer_fixed import ImageEmbeddingTrainerViT
from module.classification_package.src.datamodule import ImageEmbeddingDataModule

%matplotlib inline
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

def read_json(path):
    """Load and return the JSON document stored at *path*.

    The file is decoded explicitly as UTF-8 (the encoding JSON mandates)
    instead of the platform's locale-dependent default, so class names with
    non-ASCII characters load the same on every machine.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

In [None]:
# Configuration for the whole notebook: checkpoint/dataset locations, database
# size, and the thresholds used by the k-NN classifier in section 7.
CONFIG = {
    'checkpoint_path': '/home/fishial/Fishial/Experiments/v10/beitv2_base_patch16_224.in1k_ft_in22k_in1k_20260127_073527/checkpoints/model-epoch=58-val/accuracy_epoch=0.9498.ckpt',
    # Dataset without train/val tags. NOTE(review): the DataModule cell still
    # passes val_tag="val" — confirm which is true for this dataset.
    'dataset_name': 'classification_v0.10_train',
    'output_dir': '/home/fishial/Fishial/Experiments/v10/embedding_database/beitv2_base_patch16_224.in1k_ft_in22k_in1k_20260127_073527',
    'exclude_classes': ['unset', 'Thunnus obesus'],
    'label_path': '/home/fishial/Fishial/Experiments/v10/beitv2_base_patch16_224.in1k_ft_in22k_in1k_20260127_073527/labels.json',
    'coco_path': '/home/fishial/Fishial/dataset/EXPORT_V_0_9/Fishial_Export_Jan_08_2026_04_14_Production_AI_Gen_All_Verified.json',
    # Database parameters
    'samples_per_class': 100,  # Keep the 100 samples closest to each class centroid
    'batch_size': 64,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    
    # k-NN parameters (for testing)
    'topk_centroid': 5,          # how many nearest class centroids are considered per query
    'topk_neighbors': 10,        # k of the neighbor search inside those classes
    'centroid_threshold': 0.7,   # min cosine similarity for a centroid's class to be kept
    'neighbor_threshold': 0.8,   # min cosine similarity for a neighbor to cast a vote
    # Exact match handling: intended to control whether database entries that
    # are (near-)identical to the query are ignored during k-NN voting.
    'exclude_exact_matches': False,
    'exact_match_tolerance': 1e-4,
    
    # Dataset split (only if dataset has no train/val tags)
    # NOTE(review): in the cells visible here this flag is only printed; the
    # DataModule cell sets train_tag/val_tag directly — confirm intent.
    'use_full_dataset': True,  # True = load all as train, False = use tags
}

# Create output directory
Path(CONFIG['output_dir']).mkdir(parents=True, exist_ok=True)
print(f"Device: {CONFIG['device']}")
print(f"Checkpoint: {Path(CONFIG['checkpoint_path']).name}")
print(f"Dataset: {CONFIG['dataset_name']} (full dataset: {CONFIG['use_full_dataset']})")

## 2. Load Model & Dataset

In [None]:
# Load model from Lightning checkpoint
print("Loading model...")
model = ImageEmbeddingTrainerViT.load_from_checkpoint(
    CONFIG['checkpoint_path'],
    map_location=CONFIG['device']
).eval()

print("Model loaded!")
print(f"  Embedding dim: {model.hparams.embedding_dim}")
print(f"  Num classes: {model.hparams.num_classes}")

coco = read_json(CONFIG['coco_path'])
# JSON object keys are strings: {"0": "Species name", ...}
labels_dict = read_json(CONFIG['label_path'])

# Class name -> Fishial species_id. The class name is stored in the
# 'supercategory' of the COCO categories named 'General body shape'.
label_to_species_id = {
    c['supercategory']: c.get('fishial_extra', {}).get('species_id')
    for c in coco.get('categories', [])
    if c.get('name') == 'General body shape' and 'fishial_extra' in c
}

# A class present in labels.json but absent from the COCO export used to
# abort the notebook with a bare KeyError; keep it with species_id=None and
# report it instead.
missing_species = sorted(
    name for name in labels_dict.values() if name not in label_to_species_id
)
if missing_species:
    print(f"⚠️ {len(missing_species)} classes have no species_id in the COCO export, "
          f"e.g. {missing_species[:5]}")

# int label id -> {"label": class name, "species_id": Fishial species id or None}
labels_keys = {
    int(label_id): {
        "label": class_name,
        "species_id": label_to_species_id.get(class_name),
    }
    for label_id, class_name in labels_dict.items()
}

In [None]:
# label_to_name: int label id -> class name, used for the conflict analysis.
# read_json() yields JSON object keys as strings ({"0": "Species A", ...}),
# so each key is cast to int to line up with the integer training labels.
label_to_name = {int(raw_id): name for raw_id, name in labels_dict.items()}

print(f"Created label_to_name mapping with {len(label_to_name)} classes")
for example_id in range(3):
    print(f"Example: label {example_id} = '{label_to_name.get(example_id, 'N/A')}'")

In [None]:
# Load dataset
# Builds the DataModule used only for its datasets/val loader; the training-time
# sampler settings are required by the constructor but irrelevant here.
print("Loading dataset...")
datamodule = ImageEmbeddingDataModule(
    dataset_name=CONFIG['dataset_name'],
    batch_size=CONFIG['batch_size'],
    classes_per_batch=32,  # Not used for inference, but required by DataModule
    samples_per_class=6,   # Not used for inference, but required by DataModule
    image_size=224,
    exclude_classes=CONFIG['exclude_classes'],
    augmentation_preset='basic',  # NOTE(review): confirm this preset applies no random augmentation during extraction
    # None = load all samples into train. NOTE(review): confirm whether samples
    # tagged "val" are then ALSO in train_dataset; if so they end up in the
    # centroids / k-NN database and the validation comparison is optimistic.
    train_tag=None,
    val_tag="val",    # samples tagged "val" form the validation split (None would mean no validation split)
    class_mapping_path=CONFIG['label_path'],
    num_workers=4,
)
datamodule.setup('fit')

print(f"Dataset loaded!")
print(f"  Train samples: {len(datamodule.train_dataset)}")
if datamodule.val_dataset is not None:
    print(f"  Val samples: {len(datamodule.val_dataset)}")
else:
    print(f"  Val samples: 0 (no validation split)")
print(f"  Num classes: {len(labels_keys)}")

## 3. Extract Embeddings from Train Set

In [None]:
def create_sequential_dataloader(dataset, batch_size, num_workers=4, pin_memory=None):
    """Create a sequential (non-shuffled) DataLoader for embedding extraction.

    Preserving dataset order means row ``i`` of the extracted embeddings
    corresponds to sample ``i`` of ``dataset``.

    Args:
        dataset: any map-style dataset accepted by ``torch.utils.data.DataLoader``.
        batch_size: number of samples per batch.
        num_workers: worker processes used for loading.
        pin_memory: use page-locked host memory. ``None`` (default) enables it
            only when CUDA is available — pinning brings nothing on a CPU-only
            run and makes PyTorch emit a warning.

    Returns:
        A ``DataLoader`` with ``shuffle=False``.
    """
    from torch.utils.data import DataLoader

    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # Important: sequential order
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

def _metadata_to_list(values):
    """Return one batch of a metadata field as a plain Python list.

    PyTorch's default collate stacks numeric per-sample fields into a tensor
    (and keeps strings as a list); ``.tolist()`` turns the tensor case into
    Python scalars so the collected ids are not a list of 0-d tensors.
    """
    if isinstance(values, (torch.Tensor, np.ndarray)):
        return values.tolist()
    return list(values)


def extract_embeddings(model, dataloader, device):
    """
    Extract normalized embeddings and ArcFace logits from model.

    Runs ``model`` in eval mode under ``torch.no_grad()`` over every batch of
    ``dataloader``. Each batch is unpacked as ``(images, targets, _, metadata)``
    and ``model(images)`` must return ``(embeddings, logits, _)``.

    Args:
        model: module whose forward returns (normalized embeddings, logits, extra).
        dataloader: iterable of ``(images, targets, _, metadata)`` batches;
            ``metadata`` is a dict that may contain 'image_id',
            'annotation_id' and 'drawn_fish_id'.
        device: device the images are moved to before the forward pass.

    Returns:
        dict with 'embeddings' and 'logits' (numpy arrays, one row per sample),
        'labels' (numpy array), and 'image_ids', 'annotation_ids',
        'drawn_fish_ids' as plain Python lists, or None when the corresponding
        metadata key never appeared.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    embeddings = []
    labels = []
    logits_list = []
    # metadata key as provided by the dataset -> key in the returned dict
    metadata_fields = {
        'image_id': 'image_ids',
        'annotation_id': 'annotation_ids',
        'drawn_fish_id': 'drawn_fish_ids',
    }
    collected = {out_key: [] for out_key in metadata_fields.values()}

    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Extracting embeddings'):
            images, targets, _, metadata = batch
            images = images.to(device)

            # Get embeddings and logits
            emb_norm, logits, _ = model(images)

            embeddings.append(emb_norm.cpu())
            logits_list.append(logits.cpu())
            labels.extend(targets.cpu().numpy())

            for in_key, out_key in metadata_fields.items():
                if in_key in metadata:
                    collected[out_key].extend(_metadata_to_list(metadata[in_key]))

    if not embeddings:
        # torch.cat([]) would otherwise fail with an unrelated-looking error.
        raise ValueError("extract_embeddings: the dataloader yielded no batches")

    result = {
        'embeddings': torch.cat(embeddings, dim=0).numpy(),
        'logits': torch.cat(logits_list, dim=0).numpy(),
        'labels': np.array(labels),
    }
    for out_key, values in collected.items():
        result[out_key] = values if values else None
    return result

In [None]:
# Extract train embeddings.
# A plain sequential loader is used here instead of the training-time
# MPerClassSampler so that embedding row i stays aligned with dataset sample i.
print("Extracting train embeddings...")
train_dataloader = create_sequential_dataloader(
    datamodule.train_dataset,
    num_workers=4,
    batch_size=CONFIG['batch_size'],
)

In [None]:

train_data = extract_embeddings(
    model=model,
    dataloader=train_dataloader,
    device=CONFIG['device'],
)

# Report the number of extracted rows and the full embedding matrix shape.
print(f"Extracted {len(train_data['embeddings'])} train embeddings")
print(f"Embedding shape: {train_data['embeddings'].shape}")

In [None]:
# Extract validation embeddings (if available)
if datamodule.val_dataset is not None:
    print("Extracting validation embeddings...")
    val_data = extract_embeddings(
        model,
        datamodule.val_dataloader(),
        CONFIG['device']
    )
    print(f"Extracted {len(val_data['embeddings'])} validation embeddings")
else:
    print("⚠️ No validation split found.")
    print("For testing, we'll use a random 10% subset of train data.")
    # NOTE: the sampled rows stay in train_data, so they also feed the centroids
    # and the k-NN database built below; a query can then find itself in the
    # database, and results on this subset are optimistic.

    # Reproducible random 10% of the train rows.
    n_val = len(train_data['embeddings']) // 10
    indices = np.random.RandomState(42).permutation(len(train_data['embeddings']))
    val_indices = indices[:n_val]

    def _take_list(values):
        """Subset a metadata list by val_indices; None/empty stays None."""
        return [values[i] for i in val_indices] if values else None

    val_data = {
        'embeddings': train_data['embeddings'][val_indices],
        'logits': train_data['logits'][val_indices],
        'labels': train_data['labels'][val_indices],
        'image_ids': _take_list(train_data['image_ids']),
        'annotation_ids': _take_list(train_data['annotation_ids']),
        'drawn_fish_ids': _take_list(train_data['drawn_fish_ids']),
    }
    print(f"Created validation subset: {len(val_data['embeddings'])} samples")

## 4. Compute Class Centroids

In [None]:
def compute_centroids(embeddings, labels):
    """
    Return a dict mapping each distinct label to its unit-length mean embedding.

    For every value in ``np.unique(labels)``, the rows of ``embeddings`` carrying
    that label are averaged and the mean is divided by its L2 norm (plus 1e-10
    so an all-zero mean does not cause a division by zero).
    """
    result = {}
    for cls in tqdm(np.unique(labels), desc='Computing centroids'):
        mean_vec = embeddings[labels == cls].mean(axis=0)
        result[cls] = mean_vec / (np.linalg.norm(mean_vec) + 1e-10)
    return result

In [None]:
# Compute one normalized centroid per training class.
centroids = compute_centroids(train_data['embeddings'], train_data['labels'])
print(f"Computed {len(centroids)} class centroids")

# Stack the centroids into a matrix for vectorized similarity computation;
# row i of centroid_matrix belongs to centroid_labels[i].
centroid_labels = [*centroids]
centroid_matrix = np.vstack(list(centroids.values()))
print(f"Centroid matrix shape: {centroid_matrix.shape}")

## 5. Select Top-100 Representative Samples per Class

In [None]:
# def select_representative_samples(embeddings, labels, metadata, centroids, samples_per_class=100):
#     """
#     Select top-N most representative samples per class (closest to centroid).
    
#     ⚠️ DEPRECATED: This is the old version without conflict filtering.
#     Use select_representative_samples_v2() instead for better results.
#     """
#     selected_indices = []
#     unique_labels = np.unique(labels)
    
#     for label in tqdm(unique_labels, desc='Selecting samples'):
#         class_mask = labels == label
#         class_embeddings = embeddings[class_mask]
#         class_indices = np.where(class_mask)[0]
        
#         # Compute distances to centroid
#         centroid = centroids[label]
#         distances = 1.0 - np.dot(class_embeddings, centroid)  # Cosine distance
        
#         # Select top-N closest to centroid
#         n_select = min(samples_per_class, len(class_indices))
#         top_n_local_indices = np.argsort(distances)[:n_select]
#         top_n_global_indices = class_indices[top_n_local_indices]
        
#         selected_indices.extend(top_n_global_indices)
    
#     selected_indices = np.array(selected_indices)
    
#     # Create filtered database
#     db = {
#         'embeddings': embeddings[selected_indices],
#         'labels': labels[selected_indices],
#     }
    
#     # Add metadata if available
#     for key in ['image_ids', 'annotation_ids', 'drawn_fish_ids']:
#         if key in metadata and metadata[key] is not None:
#             db[key] = [metadata[key][i] for i in selected_indices]
    
#     return db, selected_indices

def filter_conflicting_embeddings(embeddings, labels, metadata, 
                                   similarity_threshold=0.8, 
                                   min_neighbors_check=5,
                                   verbose=True):
    """
    Filter out embeddings that are too similar to samples from other classes.

    This helps remove potential labeling errors where identical/similar images
    are assigned to different classes. For every sample, its nearest neighbors
    by cosine similarity are inspected. Any neighbor with a different label and
    similarity >= ``similarity_threshold`` marks BOTH samples as conflicting,
    and both are removed.

    Args:
        embeddings: numpy array of embeddings (N, D)
        labels: numpy array of labels (N,)
        metadata: dict that may hold 'image_ids', 'annotation_ids',
            'drawn_fish_ids'; each value is a list/array aligned with
            ``embeddings`` or None
        similarity_threshold: cosine similarity at/above which a cross-class
            neighbor counts as a conflict (0.95 = very similar)
        min_neighbors_check: how many nearest neighbors (besides the sample
            itself) to check
        verbose: print statistics

    Returns:
        dict with the filtered 'embeddings', 'labels' and 'metadata', the
        'removed_indices' (into the input arrays), 'conflict_details' (one
        entry per unique conflicting pair), 'n_original' and 'n_filtered'.
    """
    N = len(embeddings)

    def _pct(count):
        return count / N * 100 if N else 0.0

    def _meta(key, idx):
        """Metadata value for row idx, or None if that field is unavailable."""
        values = metadata.get(key)
        return values[idx] if values is not None else None

    # Build FAISS index for fast similarity search
    if verbose:
        print(f"Building FAISS index for {N} embeddings...")

    # Normalize (guarding zero vectors) so inner product == cosine similarity.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalized_embeddings = (embeddings / np.maximum(norms, 1e-12)).astype('float32')

    dimension = embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)  # Inner Product = cosine similarity (after normalization)
    index.add(normalized_embeddings)

    if verbose:
        print(f"Searching for conflicting embeddings (similarity >= {similarity_threshold})...")

    # One extra neighbor because the query itself is normally returned too.
    # It is skipped by identity below, not by position: with exact duplicates
    # the tie at similarity 1.0 can rank the duplicate ahead of the query.
    # Capping k at N avoids FAISS padding results with -1.
    k = min(min_neighbors_check + 1, N)
    if k > 0:
        similarities, indices = index.search(normalized_embeddings, k)
    else:
        similarities = np.empty((N, 0), dtype='float32')
        indices = np.empty((N, 0), dtype='int64')

    conflicting_indices = set()
    conflict_details = []
    seen_pairs = set()

    for i in tqdm(range(N), desc='Checking conflicts', disable=not verbose):
        my_label = labels[i]

        for neighbor_idx, sim in zip(indices[i], similarities[i]):
            neighbor_idx = int(neighbor_idx)
            # -1 marks an empty FAISS slot; also skip the sample itself.
            if neighbor_idx < 0 or neighbor_idx == i:
                continue
            neighbor_label = labels[neighbor_idx]

            # Only very similar samples with a different label are conflicts.
            if sim < similarity_threshold or neighbor_label == my_label:
                continue

            conflicting_indices.add(i)
            conflicting_indices.add(neighbor_idx)

            # A pair is usually found from both of its ends; record it once.
            pair = (min(i, neighbor_idx), max(i, neighbor_idx))
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)

            conflict_details.append({
                'idx1': i,
                'idx2': neighbor_idx,
                'label1': my_label,
                'label2': neighbor_label,
                'similarity': float(sim),
                'image_id1': _meta('image_ids', i),
                'image_id2': _meta('image_ids', neighbor_idx),
                'ann_id1': _meta('annotation_ids', i),
                'ann_id2': _meta('annotation_ids', neighbor_idx),
            })

            if verbose and len(conflict_details) <= 10:  # Print first 10
                print(f"  Conflict: idx {i} (label {my_label}) <-> idx {neighbor_idx} (label {neighbor_label}), sim={sim:.4f}")

    # Create mask for non-conflicting samples
    keep_mask = np.ones(N, dtype=bool)
    if conflicting_indices:
        keep_mask[sorted(conflicting_indices)] = False
    n_kept = int(keep_mask.sum())

    if verbose:
        print(f"\nConflict Statistics:")
        print(f"  Total samples: {N}")
        print(f"  Conflicting samples: {len(conflicting_indices)} ({_pct(len(conflicting_indices)):.2f}%)")
        print(f"  Remaining samples: {n_kept} ({_pct(n_kept):.2f}%)")
        print(f"  Unique conflict pairs: {len(conflict_details)}")

    # Filter data
    filtered_embeddings = embeddings[keep_mask]
    filtered_labels = labels[keep_mask]

    kept_positions = np.flatnonzero(keep_mask)
    filtered_metadata = {}
    for key in ['image_ids', 'annotation_ids', 'drawn_fish_ids']:
        values = metadata.get(key)
        if values is None:
            continue
        if isinstance(values, np.ndarray):
            filtered_metadata[key] = values[keep_mask]
        else:
            filtered_metadata[key] = [values[i] for i in kept_positions]

    removed_indices = np.where(~keep_mask)[0]

    return {
        'embeddings': filtered_embeddings,
        'labels': filtered_labels,
        'metadata': filtered_metadata,
        'removed_indices': removed_indices,
        'conflict_details': conflict_details,
        'n_original': N,
        'n_filtered': n_kept,
    }

def select_representative_samples_v2(embeddings, labels, metadata, centroids, 
                                      samples_per_class=100,
                                      filter_conflicts=True,
                                      conflict_similarity_threshold=0.95,
                                      conflict_neighbors_check=5):
    """
    Select top-N most representative samples per class (closest to centroid).

    Optionally removes conflicting (potentially mislabeled) samples first using
    ``filter_conflicting_embeddings``. Then, for every remaining label, it keeps
    the ``samples_per_class`` rows with the smallest distance
    ``1 - <embedding, centroid>`` to that label's centroid. This is the cosine
    distance, assuming unit-length embeddings.

    Args:
        embeddings: numpy array of embeddings (N, D)
        labels: numpy array of labels (N,)
        metadata: dict with 'image_ids', 'annotation_ids', 'drawn_fish_ids'
            (each a list/array aligned with ``embeddings``, or None)
        centroids: dict mapping each label to its centroid vector (D,), as
            returned by ``compute_centroids``
        samples_per_class: number of samples to select per class
        filter_conflicts: whether to filter conflicting embeddings first
        conflict_similarity_threshold: similarity threshold for conflict detection
        conflict_neighbors_check: number of neighbors to check for conflicts

    Returns:
        db: dict with the selected 'embeddings', 'labels' and any available
            metadata lists
        selected_indices: indices of the selected samples into the ORIGINAL
            ``embeddings``/``labels`` arrays (i.e. before conflict filtering)
        filter_info: dict with conflict filtering statistics, or None when
            ``filter_conflicts`` is False
    """
    filter_info = None
    # Maps a row of the (possibly filtered) arrays back to the caller's index.
    original_positions = np.arange(len(embeddings))

    # Step 1: Filter conflicts if requested
    if filter_conflicts:
        print("=" * 60)
        print("STEP 1: Filtering conflicting embeddings")
        print("=" * 60)

        filter_result = filter_conflicting_embeddings(
            embeddings, 
            labels, 
            metadata,
            similarity_threshold=conflict_similarity_threshold,
            min_neighbors_check=conflict_neighbors_check,
            verbose=True
        )

        kept_mask = np.ones(len(embeddings), dtype=bool)
        kept_mask[filter_result['removed_indices']] = False
        original_positions = original_positions[kept_mask]

        embeddings = filter_result['embeddings']
        labels = filter_result['labels']
        metadata = filter_result['metadata']
        filter_info = {
            'removed_indices': filter_result['removed_indices'],
            'conflict_details': filter_result['conflict_details'],
            'n_original': filter_result['n_original'],
            'n_filtered': filter_result['n_filtered']
        }

        print(f"\n✓ Filtered from {filter_result['n_original']} to {filter_result['n_filtered']} samples")
        print()

    # Step 2: Select representative samples
    print("=" * 60)
    print("STEP 2: Selecting representative samples per class")
    print("=" * 60)

    local_selected = []
    for label in tqdm(np.unique(labels), desc='Selecting samples'):
        class_indices = np.flatnonzero(labels == label)

        # Cosine distance to the class centroid (embeddings assumed unit-length)
        distances = 1.0 - np.dot(embeddings[class_indices], centroids[label])

        # Select top-N closest to centroid
        n_select = min(samples_per_class, len(class_indices))
        local_selected.extend(class_indices[np.argsort(distances)[:n_select]])

    # Explicit integer dtype so an empty selection still indexes arrays.
    local_selected = np.asarray(local_selected, dtype=np.int64)

    # Create filtered database
    db = {
        'embeddings': embeddings[local_selected],
        'labels': labels[local_selected],
    }

    # Add metadata if available
    for key in ['image_ids', 'annotation_ids', 'drawn_fish_ids']:
        values = metadata.get(key)
        if values is None:
            continue
        if isinstance(values, np.ndarray):
            db[key] = values[local_selected]
        else:
            db[key] = [values[i] for i in local_selected]

    print(f"\n✓ Selected {len(db['embeddings'])} representative samples")

    selected_indices = original_positions[local_selected]
    return db, selected_indices, filter_info

In [None]:
# Apply conflict filtering and select representative samples.
# Note: `centroids` was computed from all train samples, including any that the
# conflict filter removes here.
database, selected_indices, filter_info = select_representative_samples_v2(
    train_data['embeddings'],
    train_data['labels'],
    {
        'image_ids': train_data['image_ids'],
        'annotation_ids': train_data['annotation_ids'],
        'drawn_fish_ids': train_data['drawn_fish_ids'],
    },
    centroids,
    samples_per_class=CONFIG['samples_per_class'],
    filter_conflicts=True,  # Enable conflict filtering
    conflict_similarity_threshold=0.99,  # cross-class neighbors at/above this cosine similarity (near-duplicates) count as conflicts
    conflict_neighbors_check=5  # Check 5 nearest neighbors
)

print("\n" + "=" * 60)
print("FINAL DATABASE STATISTICS")
print("=" * 60)
print(f"Database embedding shape: {database['embeddings'].shape}")
print(f"Total samples in database: {len(database['embeddings'])}")

# Per-class statistics (classes with fewer than samples_per_class samples keep all of them)
label_counts = Counter(database['labels'])
print(f"\nSamples per class distribution:")
print(f"  Min: {min(label_counts.values())}")
print(f"  Max: {max(label_counts.values())}")
print(f"  Mean: {np.mean(list(label_counts.values())):.1f}")
print(f"  Median: {np.median(list(label_counts.values())):.1f}")

# filter_info is None when filtering was disabled
if filter_info:
    print(f"\nConflict Filtering Impact:")
    print(f"  Original samples: {filter_info['n_original']}")
    print(f"  After filtering: {filter_info['n_filtered']}")
    print(f"  Removed: {len(filter_info['removed_indices'])} ({len(filter_info['removed_indices'])/filter_info['n_original']*100:.2f}%)")
    print(f"  Conflict pairs found: {len(filter_info['conflict_details'])}")

## 6. Save Database

In [None]:
# Save database (file name reflects the configured samples-per-class budget)
database_path = Path(CONFIG['output_dir']) / f"embedding_database_beitv2_top{CONFIG['samples_per_class']}.pt"

# NOTE(review): the payload contains numpy arrays and dicts keyed by numpy
# scalars, so torch.load(..., weights_only=True) — the default since
# PyTorch 2.6 — is expected to reject it; load with weights_only=False and
# only from a trusted location.
torch.save({
    'embeddings': torch.from_numpy(database['embeddings']),
    'labels': database['labels'],
    'image_ids': database.get('image_ids'),
    'annotation_ids': database.get('annotation_ids'),
    'drawn_fish_ids': database.get('drawn_fish_ids'),
    'labels_keys': labels_keys,
    'centroids': centroids,
    'config': CONFIG,
}, database_path)

print(f"✅ Database saved to: {database_path}")
print(f"   Size: {database_path.stat().st_size / 1024 / 1024:.1f} MB")

## 7. k-NN Classifier Implementation

In [None]:
class EmbeddingKNNClassifier:
    """
    k-NN classifier using embeddings, centroids, and FAISS.

    For each query embedding, ``predict``:
      1. ranks all class centroids by cosine similarity. It keeps the top
         ``config['topk_centroid']`` classes whose similarity is at least
         ``config['centroid_threshold']``. If none qualify, it falls back to
         the single best centroid.
      2. runs an inner-product search for ``config['topk_neighbors']``
         neighbors over the database rows of those classes.
      3. lets neighbors with similarity >= ``config['neighbor_threshold']``
         vote. When ``config['exclude_exact_matches']`` is true, neighbors with
         similarity >= ``1 - config['exact_match_tolerance']`` are ignored.
      4. scores every candidate class as
         ``0.3 * centroid_sim + 0.7 * mean similarity of its voting neighbors``
         and predicts the highest-scoring class.

    Args:
        database: dict with 'embeddings' and 'labels' numpy arrays. The inner
            product equals cosine similarity only for L2-normalized rows; this
            is assumed (the notebook stores the model's normalized embeddings).
        centroids: dict mapping label -> centroid vector.
        config: dict holding the keys listed above; 'exclude_exact_matches'
            defaults to False and 'exact_match_tolerance' to 1e-4.
    """
    def __init__(self, database, centroids, config):
        self.db_embeddings = database['embeddings']
        self.db_labels = database['labels']
        self.centroids = centroids
        self.config = config

        # float32, C-contiguous copy for FAISS; converted once instead of per query.
        self._db_embeddings_f32 = np.ascontiguousarray(self.db_embeddings, dtype='float32')

        # Prepare centroid matrix (row i belongs to self.centroid_labels[i])
        self.centroid_labels = list(centroids.keys())
        self.centroid_matrix = np.stack([centroids[label] for label in self.centroid_labels])
        self.exclude_exact_matches = self.config.get('exclude_exact_matches', False)
        self.exact_match_tolerance = self.config.get('exact_match_tolerance', 1e-4)

        print(f"KNN Classifier initialized with {len(self.db_embeddings)} samples")

    def predict(self, query_embeddings, return_details=False):
        """
        Predict using centroid filtering + k-NN search.

        Args:
            query_embeddings: 2-D numpy array or torch tensor, one query per row.
            return_details: also return per-query scoring details.

        Returns:
            numpy array of predicted labels. If ``return_details`` is true, the
            result is a tuple ``(predictions, details)``. Each details entry has
            'centroid_scores' and 'neighbor_votes', plus 'final_scores' when the
            k-NN stage ran.
        """
        if isinstance(query_embeddings, torch.Tensor):
            query_embeddings = query_embeddings.cpu().numpy()

        predictions = []
        details = []

        for query_emb in query_embeddings:
            query_row = query_emb.reshape(1, -1)

            # Step 1: Find top-K centroids
            centroid_sims = 1.0 - pairwise_distances(
                query_row,
                self.centroid_matrix,
                metric='cosine'
            )[0]

            top_centroid_indices = np.argsort(-centroid_sims)[:self.config['topk_centroid']]

            # Filter by threshold
            centroid_scores = {
                self.centroid_labels[idx]: centroid_sims[idx]
                for idx in top_centroid_indices 
                if centroid_sims[idx] >= self.config['centroid_threshold']
            }

            if not centroid_scores:
                # Fallback: use top-1 centroid
                best_idx = np.argmax(centroid_sims)
                predictions.append(self.centroid_labels[best_idx])
                if return_details:
                    details.append({'centroid_scores': {}, 'neighbor_votes': {}})
                continue

            selected_classes = set(centroid_scores.keys())

            # Step 2: Filter database by selected classes
            class_mask = np.isin(self.db_labels, list(selected_classes))
            selected_embeddings = self._db_embeddings_f32[class_mask]
            selected_labels = self.db_labels[class_mask]

            if len(selected_embeddings) == 0:
                best_class = max(centroid_scores, key=centroid_scores.get)
                predictions.append(best_class)
                if return_details:
                    details.append({'centroid_scores': centroid_scores, 'neighbor_votes': {}})
                continue

            # Step 3: k-NN search with FAISS
            dim = selected_embeddings.shape[1]
            index = faiss.IndexFlatIP(dim)  # Inner product (for normalized vectors = cosine)
            index.add(selected_embeddings)

            k = min(self.config['topk_neighbors'], len(selected_embeddings))
            distances, indices = index.search(query_row.astype('float32'), k)

            # Step 4: Vote from neighbors
            neighbor_votes = defaultdict(lambda: {'count': 0, 'total_sim': 0.0})
            for idx, sim in zip(indices[0], distances[0]):
                if idx < 0:  # empty FAISS result slot
                    continue
                # Optionally ignore (near-)identical database entries, e.g. the
                # query sample itself when it is also stored in the database.
                # A tolerance is needed: float32 self-similarity is often just
                # below 1.0.
                if self.exclude_exact_matches and sim >= 1.0 - self.exact_match_tolerance:
                    continue
                if sim >= self.config['neighbor_threshold']:
                    label = selected_labels[idx]
                    neighbor_votes[label]['count'] += 1
                    neighbor_votes[label]['total_sim'] += sim

            # Step 5: Combine centroid + neighbor scores
            final_scores = {}
            for label in selected_classes:
                centroid_score = centroid_scores.get(label, 0.0)
                votes = neighbor_votes.get(label)  # .get avoids inserting into the defaultdict
                neighbor_score = votes['total_sim'] if votes else 0.0
                neighbor_count = votes['count'] if votes else 0

                # Weighted combination: centroid similarity + mean neighbor similarity
                final_scores[label] = (
                    0.3 * centroid_score + 
                    0.7 * (neighbor_score / max(neighbor_count, 1))
                )

            # Predict
            best_label = max(final_scores, key=final_scores.get)
            predictions.append(best_label)

            if return_details:
                details.append({
                    'centroid_scores': centroid_scores,
                    'neighbor_votes': dict(neighbor_votes),
                    'final_scores': final_scores,
                })

        if return_details:
            return np.array(predictions), details
        return np.array(predictions)

In [None]:
# Initialize classifier
knn_classifier = EmbeddingKNNClassifier(
    database=database,
    centroids=centroids,
    config=CONFIG
)

## 8. Testing: ArcFace Only vs ArcFace + k-NN

In [None]:
# Method A (baseline): classify with the ArcFace head alone, i.e. take the
# arg-max logit for every validation sample.
print("Method A: ArcFace head only")
arcface_logits = val_data['logits']
predictions_arcface = arcface_logits.argmax(axis=1)
acc_arcface = accuracy_score(val_data['labels'], predictions_arcface)
print(f"  Accuracy: {acc_arcface:.4f} ({acc_arcface*100:.2f}%)")

# Top-k accuracies are computed over the full label space (one logit column per
# class), so classes absent from the validation labels remain valid predictions.
arcface_class_ids = np.arange(arcface_logits.shape[1])

top2_arcface = top_k_accuracy_score(val_data['labels'], arcface_logits, k=2, labels=arcface_class_ids)
print(f"  Top-2 Accuracy: {top2_arcface:.4f} ({top2_arcface*100:.2f}%)")

top5_arcface = top_k_accuracy_score(val_data['labels'], arcface_logits, k=5, labels=arcface_class_ids)
print(f"  Top-5 Accuracy: {top5_arcface:.4f} ({top5_arcface*100:.2f}%)")

In [None]:
# Method B: ArcFace embeddings + centroid-filtered k-NN
print("Method B: ArcFace + k-NN")
predictions_knn, details_knn = knn_classifier.predict(
    val_data['embeddings'], 
    return_details=True
)
acc_knn = accuracy_score(val_data['labels'], predictions_knn)
print(f"  Accuracy: {acc_knn:.4f} ({acc_knn*100:.2f}%)")

# Top-2 accuracy for k-NN: rank the combined scores of the candidate classes.
# Queries that took a fallback path (no 'final_scores') only count when their
# single top-1 prediction is correct.
top2_correct = 0
for true_label, pred, detail in zip(val_data['labels'], predictions_knn, details_knn):
    if 'final_scores' in detail:
        scores = detail['final_scores']
        top2_labels = sorted(scores, key=scores.get, reverse=True)[:2]
        top2_correct += int(true_label in top2_labels)
    else:
        top2_correct += int(pred == true_label)

top2_knn = top2_correct / len(val_data['labels'])
print(f"  Top-2 Accuracy: {top2_knn:.4f} ({top2_knn*100:.2f}%)")

# Improvement
improvement = acc_knn - acc_arcface
improvement_top2 = top2_knn - top2_arcface
print(f"\n📈 Improvement (Top-1): {improvement:+.4f} ({improvement*100:+.2f}%)")
print(f"📈 Improvement (Top-2): {improvement_top2:+.4f} ({improvement_top2*100:+.2f}%)")

# Use the same 0.1 percentage-point band in both directions, so tiny gains are
# not reported as "helps" while equally tiny losses count as "no difference".
significance = 0.001
if improvement > significance:
    print(f"✅ k-NN helps! {improvement*100:.2f}% better")
elif improvement < -significance:
    print(f"⚠️ k-NN hurts! {-improvement*100:.2f}% worse")
else:
    print("➡️ No significant difference")

## 9. Detailed Analysis

In [None]:
# Per-class accuracy comparison
def compute_per_class_accuracy(y_true, y_pred, labels_keys):
    """Compute accuracy separately for every class present in ``y_true``.

    Args:
        y_true: Ground-truth class ids (array-like, one per sample).
        y_pred: Predicted class ids, aligned element-wise with ``y_true``.
        labels_keys: Mapping from the class id as a string (e.g. ``"12"``) to a
            dict holding the human-readable class name under ``'label'``.

    Returns:
        DataFrame with columns ``class_id``, ``class_name``, ``count`` and
        ``accuracy``. It has one row per distinct label in ``y_true``, in
        ascending label order. The columns are present even when ``y_true`` is
        empty, so callers can still merge on ``class_id``.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different lengths.
        KeyError: if a label in ``y_true`` has no entry in ``labels_keys``.
    """
    # Accept lists/tuples as well as arrays: the boolean masking below needs ndarrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"y_true and y_pred must have the same length, got {len(y_true)} and {len(y_pred)}"
        )

    results = []
    for label in np.unique(y_true):
        # Labels come from y_true itself, so the mask always selects >= 1 sample.
        mask = y_true == label
        results.append({
            'class_id': label,
            'class_name': labels_keys[str(int(label))]['label'],
            'count': mask.sum(),
            # Fraction of this class's samples predicted as this class.
            'accuracy': float(np.mean(y_pred[mask] == label)),
        })

    return pd.DataFrame(results, columns=['class_id', 'class_name', 'count', 'accuracy'])

In [None]:
# Per-class accuracy for each method, joined on class id so the two can be compared.
acc_arcface_df = (
    compute_per_class_accuracy(val_data['labels'], predictions_arcface, labels_keys)
    .rename(columns={'accuracy': 'acc_arcface'})
)
acc_knn_df = (
    compute_per_class_accuracy(val_data['labels'], predictions_knn, labels_keys)
    .rename(columns={'accuracy': 'acc_knn'})
)

comparison_df = acc_arcface_df.merge(acc_knn_df[['class_id', 'acc_knn']], on='class_id')
comparison_df['improvement'] = comparison_df['acc_knn'] - comparison_df['acc_arcface']
# Largest gain first, so head() shows winners and tail() shows losers.
comparison_df = comparison_df.sort_values('improvement', ascending=False)

_comparison_cols = ['class_name', 'count', 'acc_arcface', 'acc_knn', 'improvement']

print("\nüìä Per-class accuracy comparison:")
print(f"\nTop 10 classes with MOST improvement (k-NN helps):")
print(comparison_df.head(10)[_comparison_cols].to_string(index=False))

print(f"\nTop 10 classes with LEAST improvement (k-NN hurts):")
print(comparison_df.tail(10)[_comparison_cols].to_string(index=False))

In [None]:
# Visualization: how per-class accuracy changes once k-NN is added.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_hist, ax_scatter = axes

# Left panel: histogram of per-class improvements with a marker at zero.
ax_hist.hist(comparison_df['improvement'], bins=30, edgecolor='black')
ax_hist.axvline(0, color='red', linestyle='--', label='No change')
ax_hist.set_xlabel('Improvement (k-NN - ArcFace)')
ax_hist.set_ylabel('Number of classes')
ax_hist.set_title('Distribution of Per-Class Improvements')
ax_hist.legend()
ax_hist.grid(True, alpha=0.3)

# Right panel: ArcFace vs k-NN accuracy per class. Marker size is the class's
# sample count; colour is the improvement.
ax_scatter.scatter(
    comparison_df['acc_arcface'],
    comparison_df['acc_knn'],
    s=comparison_df['count'],
    c=comparison_df['improvement'],
    cmap='RdYlGn',
    alpha=0.6,
)
ax_scatter.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='y=x')
ax_scatter.set_xlabel('ArcFace Accuracy')
ax_scatter.set_ylabel('k-NN Accuracy')
ax_scatter.set_title('Per-Class Accuracy: ArcFace vs k-NN')
ax_scatter.legend()
ax_scatter.grid(True, alpha=0.3)
ax_scatter.set_xlim([0, 1])
ax_scatter.set_ylim([0, 1])

fig.tight_layout()
fig.savefig(Path(CONFIG['output_dir']) / 'accuracy_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nüìä Classes where k-NN helps (improvement > 0.05): {(comparison_df['improvement'] > 0.05).sum()}")
print(f"üìä Classes where k-NN hurts (improvement < -0.05): {(comparison_df['improvement'] < -0.05).sum()}")

## 10. Analyze Failure Cases

In [None]:
# Find cases where ArcFace was correct but k-NN was wrong
arcface_correct = (predictions_arcface == val_data['labels'])
knn_wrong = (predictions_knn != val_data['labels'])
regression_cases = arcface_correct & knn_wrong

print(f"\nüîç Regression cases (ArcFace ‚úÖ ‚Üí k-NN ‚ùå): {regression_cases.sum()}")

if regression_cases.any():
    regression_indices = np.where(regression_cases)[0][:10]  # Show first 10

    print("\nExample regression cases:")
    for idx in regression_indices:
        # labels_keys is keyed by the class id rendered as a string, e.g. "12".
        true_name = labels_keys[str(int(val_data['labels'][idx]))]['label']
        pred_name = labels_keys[str(int(predictions_knn[idx]))]['label']
        # ArcFace was correct for these samples, so its prediction is the true class.
        print(f"  Sample {idx}: True={true_name}, ArcFace={true_name} ‚úÖ, k-NN={pred_name} ‚ùå")

In [None]:
# Find cases where ArcFace was wrong but k-NN was correct
arcface_wrong = (predictions_arcface != val_data['labels'])
knn_correct = (predictions_knn == val_data['labels'])
improvement_cases = arcface_wrong & knn_correct

print(f"\n‚úÖ Improvement cases (ArcFace ‚ùå ‚Üí k-NN ‚úÖ): {improvement_cases.sum()}")

if improvement_cases.any():
    improvement_indices = np.where(improvement_cases)[0][:10]  # Show first 10

    print("\nExample improvement cases:")
    for idx in improvement_indices:
        # labels_keys is keyed by the class id rendered as a string, e.g. "12".
        true_name = labels_keys[str(int(val_data['labels'][idx]))]['label']
        pred_arcface_name = labels_keys[str(int(predictions_arcface[idx]))]['label']
        # k-NN was correct for these samples, so its prediction is the true class.
        print(f"  Sample {idx}: True={true_name}, ArcFace={pred_arcface_name} ‚ùå, k-NN={true_name} ‚úÖ")

## 11. Save Results

In [None]:
# Save comparison dataframe (create the output directory if it is missing).
results_dir = Path(CONFIG['output_dir'])
results_dir.mkdir(parents=True, exist_ok=True)

comparison_path = results_dir / 'per_class_comparison.csv'
comparison_df.to_csv(comparison_path, index=False)
print(f"‚úÖ Per-class comparison saved to: {comparison_path}")

# Save summary. All metric values are cast to built-in float/int so json can
# serialize them even when they arrive as numpy scalars.
summary = {
    'checkpoint': CONFIG['checkpoint_path'],
    'database_size': len(database['embeddings']),
    'samples_per_class': CONFIG['samples_per_class'],
    'val_samples': len(val_data['labels']),

    'accuracy_arcface': float(acc_arcface),
    'accuracy_knn': float(acc_knn),
    'improvement_absolute': float(improvement),
    # Relative gain in percent; undefined (None) when the ArcFace baseline is 0.
    'improvement_relative': float(improvement / acc_arcface * 100) if acc_arcface else None,

    'top2_accuracy_arcface': float(top2_arcface),
    'top2_accuracy_knn': float(top2_knn),
    'top5_accuracy_arcface': float(top5_arcface),

    'regression_cases': int(regression_cases.sum()),
    'improvement_cases': int(improvement_cases.sum()),

    'config': CONFIG,
}

summary_path = results_dir / 'test_summary.json'
with open(summary_path, 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)
print(f"‚úÖ Test summary saved to: {summary_path}")

## 12. Final Summary

In [None]:
print("="*60)
print("FINAL SUMMARY")
print("="*60)
print(f"\nüì¶ Database:")
print(f"   Samples: {len(database['embeddings'])}")
print(f"   Classes: {len(centroids)}")
print(f"   Path: {database_path}")
print(f"\nüìä Validation Results:")
print(f"   Val samples: {len(val_data['labels'])}")
print(f"   ArcFace only:    {acc_arcface:.4f} ({acc_arcface*100:.2f}%)")
print(f"   ArcFace + k-NN:  {acc_knn:.4f} ({acc_knn*100:.2f}%)")
print(f"   Improvement:     {improvement:+.4f} ({improvement*100:+.2f}%)")
print(f"\nüìà Analysis:")
print(f"   Classes improved: {(comparison_df['improvement'] > 0).sum()} / {len(comparison_df)}")
print(f"   Regression cases: {regression_cases.sum()}")
print(f"   Improvement cases: {improvement_cases.sum()}")
print(f"   Net gain: {improvement_cases.sum() - regression_cases.sum():+d} samples")
print("\n" + "="*60)

if improvement > 0.001:
    print("\n‚úÖ RECOMMENDATION: Use k-NN! It improves accuracy.")
elif improvement < -0.001:
    print("\n‚ö†Ô∏è WARNING: k-NN hurts accuracy. Stick with ArcFace only.")
    print("   Consider adjusting thresholds or using different k.")
else:
    print("\n‚û°Ô∏è NEUTRAL: No significant difference. Use ArcFace for simplicity.")

## 5b. NEW: Select Representative Samples with Conflict Filtering

This improved version:
1. **Filters conflicts**: Removes embeddings that are very similar but have different labels (potential labeling errors)
2. **Selects representatives**: Chooses top-N samples per class closest to centroids
3. **Provides diagnostics**: Shows statistics about removed conflicts

In [None]:
# Apply conflict filtering and select representative samples.
# Cosine-similarity threshold for treating two differently-labelled embeddings
# as a conflict. NOTE: 0.8 is below the 0.85 that the parameter tuning guide
# already calls "very aggressive" (0.95 is its recommended setting).
CONFLICT_SIMILARITY_THRESHOLD = 0.8
# Number of nearest neighbours inspected per sample when looking for conflicts.
CONFLICT_NEIGHBORS_CHECK = 5

database_v2, selected_indices_v2, filter_info = select_representative_samples_v2(
    train_data['embeddings'],
    train_data['labels'],
    {
        'image_ids': train_data['image_ids'],
        'annotation_ids': train_data['annotation_ids'],
        'drawn_fish_ids': train_data['drawn_fish_ids'],
    },
    centroids,
    samples_per_class=CONFIG['samples_per_class'],
    filter_conflicts=True,  # Enable conflict filtering
    conflict_similarity_threshold=CONFLICT_SIMILARITY_THRESHOLD,
    conflict_neighbors_check=CONFLICT_NEIGHBORS_CHECK,
)

print("\n" + "=" * 60)
print("FINAL DATABASE STATISTICS")
print("=" * 60)
print(f"Database embedding shape: {database_v2['embeddings'].shape}")
print(f"Total samples in database: {len(database_v2['embeddings'])}")

# Per-class statistics (guarded: min()/max() raise on an empty database).
label_counts = Counter(database_v2['labels'])
print(f"\nSamples per class distribution:")
if label_counts:
    per_class_counts = list(label_counts.values())
    print(f"  Min: {min(per_class_counts)}")
    print(f"  Max: {max(per_class_counts)}")
    print(f"  Mean: {np.mean(per_class_counts):.1f}")
    print(f"  Median: {np.median(per_class_counts):.1f}")
else:
    print("  (database is empty)")

if filter_info:
    n_removed = len(filter_info['removed_indices'])
    # Avoid ZeroDivisionError when there were no samples to filter.
    removed_pct = n_removed / filter_info['n_original'] * 100 if filter_info['n_original'] else 0.0
    print(f"\nConflict Filtering Impact:")
    print(f"  Original samples: {filter_info['n_original']}")
    print(f"  After filtering: {filter_info['n_filtered']}")
    print(f"  Removed: {n_removed} ({removed_pct:.2f}%)")
    print(f"  Conflict pairs found: {len(filter_info['conflict_details'])}")

### Analyze Conflict Details

Let's investigate which classes have the most conflicts and what the most common confusion pairs are.

In [None]:
if filter_info and len(filter_info['conflict_details']) > 0:
    # Fall back to raw label IDs when the id -> class-name mapping is unavailable.
    if 'label_to_name' not in globals():
        print("‚ö†Ô∏è Warning: label_to_name not found. Using label IDs instead of class names.")
        print("   Run the cell that creates label_to_name mapping to see class names.\n")
        label_to_name = {}

    def _display_name(label):
        """Class name for `label`, or a 'Class_<id>' placeholder."""
        return label_to_name.get(label, f"Class_{label}")

    # Count conflicts per unordered class pair, and per individual class.
    conflict_pairs = defaultdict(int)
    class_conflicts = defaultdict(int)
    for record in filter_info['conflict_details']:
        lbl_a, lbl_b = record['label1'], record['label2']
        # (A, B) and (B, A) share one key.
        conflict_pairs[tuple(sorted([lbl_a, lbl_b]))] += 1
        for lbl in (lbl_a, lbl_b):
            class_conflicts[lbl] += 1

    print("Top 20 Most Common Conflict Pairs:")
    print("-" * 80)
    sorted_pairs = sorted(conflict_pairs.items(), key=lambda item: item[1], reverse=True)
    for (lbl_a, lbl_b), n_conf in sorted_pairs[:20]:
        print(f"{_display_name(lbl_a):40s} <-> {_display_name(lbl_b):40s} : {n_conf:4d} conflicts")

    print("\n" + "=" * 80)
    print("Top 20 Classes with Most Conflicts:")
    print("-" * 80)
    sorted_classes = sorted(class_conflicts.items(), key=lambda item: item[1], reverse=True)
    for lbl, n_conf in sorted_classes[:20]:
        print(f"{_display_name(lbl):60s} : {n_conf:4d} conflicts")

    # Similarity distribution of the conflicting pairs.
    similarities = [record['similarity'] for record in filter_info['conflict_details']]
    sim_mean = np.mean(similarities)
    print("\n" + "=" * 80)
    print("Conflict Similarity Distribution:")
    print("-" * 80)
    for stat_name, stat_value in (
        ("Min", min(similarities)),
        ("Max", max(similarities)),
        ("Mean", sim_mean),
        ("Median", np.median(similarities)),
    ):
        print(f"  {stat_name} similarity: {stat_value:.4f}")

    plt.figure(figsize=(10, 5))
    plt.hist(similarities, bins=50, edgecolor='black', alpha=0.7)
    plt.axvline(sim_mean, color='red', linestyle='--', label=f'Mean: {sim_mean:.4f}')
    plt.xlabel('Cosine Similarity')
    plt.ylabel('Number of Conflicts')
    plt.title('Distribution of Similarity Scores for Conflicting Embeddings')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

else:
    print("No conflicts found or conflict filtering was not enabled.")

### Save Conflict Report

Save detailed conflict information for manual inspection and potential dataset cleaning.

In [None]:
if filter_info and len(filter_info['conflict_details']) > 0:
    # Fall back to raw label IDs when the id -> class-name mapping is unavailable.
    if 'label_to_name' not in globals():
        print("‚ö†Ô∏è Warning: label_to_name not found. Using label IDs instead of class names.")
        print("   Run the cell that creates label_to_name mapping to see class names.\n")
        label_to_name = {}

    # One report row per conflict: the original conflict fields plus the
    # human-readable class names of both sides.
    conflict_report = [
        {
            **conflict,
            'class_name1': label_to_name.get(conflict['label1'], f"Class_{conflict['label1']}"),
            'class_name2': label_to_name.get(conflict['label2'], f"Class_{conflict['label2']}"),
        }
        for conflict in filter_info['conflict_details']
    ]

    # DataFrame for inspection, most similar (most suspicious) conflicts first.
    conflict_df = pd.DataFrame(conflict_report).sort_values('similarity', ascending=False)

    conflict_csv_path = Path(CONFIG['output_dir']) / 'conflict_report.csv'
    conflict_df.to_csv(conflict_csv_path, index=False)
    print(f"‚úì Saved conflict report to: {conflict_csv_path}")

    # Summary statistics. Values are cast to built-in int/float so json can
    # serialize them even when they arrive as numpy scalars.
    n_original = int(filter_info['n_original'])
    n_removed = len(filter_info['removed_indices'])
    conflict_similarities = [float(c['similarity']) for c in filter_info['conflict_details']]
    summary = {
        'total_original_samples': n_original,
        'total_filtered_samples': int(filter_info['n_filtered']),
        'total_removed_samples': n_removed,
        'removal_percentage': (n_removed / n_original * 100) if n_original else 0.0,
        'total_conflict_pairs': len(filter_info['conflict_details']),
        # Must match the arguments passed to select_representative_samples_v2 in
        # the conflict-filtering cell (conflict_similarity_threshold=0.8,
        # conflict_neighbors_check=5). Update both places together.
        'similarity_threshold': 0.8,
        'neighbors_checked': 5,
        'mean_conflict_similarity': float(np.mean(conflict_similarities)),
        'max_conflict_similarity': float(np.max(conflict_similarities)),
    }

    summary_path = Path(CONFIG['output_dir']) / 'conflict_summary.json'
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    print(f"‚úì Saved conflict summary to: {summary_path}")

    # Display the first few conflicts. Only show the columns that the conflict
    # records actually contain, so missing optional IDs do not raise KeyError.
    print("\nTop 10 Most Similar Conflicts:")
    print("-" * 120)
    display_cols = [
        col for col in ('class_name1', 'class_name2', 'similarity',
                        'image_id1', 'image_id2', 'ann_id1', 'ann_id2')
        if col in conflict_df.columns
    ]
    print(conflict_df[display_cols].head(10).to_string(index=False))

else:
    print("No conflicts to save.")

### 📝 Parameter Tuning Guide

**Conflict Filtering Parameters:**

1. **`conflict_similarity_threshold`** (default: 0.95; note that the filtering cell above currently passes 0.8)
   - Range: 0.0 to 1.0 (cosine similarity)
   - Higher = stricter (only removes very similar conflicts)
   - Lower = more aggressive (removes more samples)
   - **Recommended values:**
     - `0.99`: Very conservative - only identical/near-identical images
     - `0.95`: Balanced - catches most labeling errors (RECOMMENDED)
     - `0.90`: Aggressive - may remove valid samples
     - `0.85`: Very aggressive - use with caution

2. **`conflict_neighbors_check`** (default: 5)
   - How many nearest neighbors to check for each sample
   - Higher = more thorough but slower
   - **Recommended values:**
     - `3-5`: Fast, catches obvious conflicts (RECOMMENDED)
     - `10-20`: More thorough, slower
     - `50+`: Very thorough, much slower

**When to adjust:**
- **High conflict rate (>5%)**: Increase threshold to 0.97-0.99
- **Low conflict rate (<0.5%)**: Decrease threshold to 0.90-0.93
- **Want more aggressive cleaning**: Decrease threshold, increase neighbors
- **Want conservative cleaning**: Increase threshold, decrease neighbors

**Tips:**
- Start with default values (0.95, 5)
- Check conflict report to see actual similarity distributions
- Manually inspect top conflicts to verify they are real errors
- Adjust threshold based on your data quality

---

## 🎯 Quick Start Summary

**New Workflow with Conflict Filtering:**

```python
# 1. Apply conflict filtering + select representatives (all-in-one)
database_v2, selected_indices_v2, filter_info = select_representative_samples_v2(
    train_data['embeddings'],
    train_data['labels'],
    {'image_ids': train_data['image_ids'], ...},
    centroids,
    samples_per_class=100,
    filter_conflicts=True,              # Enable filtering
    conflict_similarity_threshold=0.95,  # Adjust based on data
    conflict_neighbors_check=5
)

# 2. Check results
print(f"Final database size: {len(database_v2['embeddings'])}")
print(f"Conflicts removed: {len(filter_info['removed_indices'])}")

# 3. Analyze and save
# - Check conflict_report.csv for detailed conflict pairs
# - Check conflict_summary.json for statistics
```

**Benefits:**
- ✅ Removes potentially mislabeled samples
- ✅ Improves database quality
- ✅ Reduces confusion between similar classes
- ✅ Provides detailed diagnostics

**Files Generated:**
- `conflict_report.csv` - Detailed list of all conflicts with similarity scores
- `conflict_summary.json` - Overall statistics

### (Optional) Visualize Conflict Pairs in FiftyOne

Use this to manually inspect conflict pairs and decide if they are true labeling errors.

In [None]:
def create_conflict_view_in_fiftyone(conflict_details, dataset_name, max_conflicts=50):
    """
    Build a FiftyOne view of the samples involved in conflict pairs, and tag
    them so each pair can be found and compared in the FiftyOne App.

    Lookup strategies are tried in this order, and the first one that yields a
    non-empty view is used:
      1. sample field ``annotation_id``
      2. sample field ``image_id`` (as string)
      3. detection ``id`` inside the ``detections`` field
      4. ``metadata.image_id`` (as str, then as int)
      5. ``coco_id``

    Matched samples (or detections) then get ``conflict_*`` fields plus
    ``conflict`` / ``conflict_pair_XX`` tags. They are saved back with
    ``sample.save()``, so this function modifies the stored dataset.

    Args:
        conflict_details: list of conflict dicts from filter_info. Keys read
            here: ``ann_id1``/``ann_id2``, ``image_id1``/``image_id2``,
            ``similarity``, and optionally ``class_name1``/``class_name2``
            (falling back to ``label1``/``label2``).
        dataset_name: name of the FiftyOne dataset
        max_conflicts: maximum number of conflicts (pairs) to process

    Returns:
        The matched view, or an empty view (``dataset.limit(0)``) when no
        strategy finds any sample.
    """
    import colorsys

    def _clean_ann_id(value):
        # Normalize an annotation ID to a non-empty string or None. Unlike a
        # truthiness test this keeps a legitimate ID of 0, and unlike str() it
        # does not turn a missing (None) ID into the literal key "None".
        if value is None:
            return None
        text = str(value).strip()
        return text or None

    def _clean_image_id(value):
        # Image IDs are matched as integers. Missing or non-numeric values are
        # skipped instead of aborting the whole lookup with a ValueError.
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    def generate_colors(n):
        """Generate N visually distinct colors in hex format."""
        colors = []
        for i in range(n):
            # Evenly spaced hues with high saturation and value for vivid colors.
            r, g, b = colorsys.hsv_to_rgb(i / n, 0.8, 0.9)
            colors.append('#{:02x}{:02x}{:02x}'.format(int(r * 255), int(g * 255), int(b * 255)))
        return colors

    # Load dataset
    dataset = fo.load_dataset(dataset_name)

    print(f"Dataset: {dataset_name}")
    print(f"Total samples in dataset: {len(dataset)}")

    selected = conflict_details[:max_conflicts]

    # Collect unique annotation IDs and image IDs from the selected conflicts
    conflict_ann_ids = set()
    conflict_image_ids = set()
    for conflict in selected:
        for key in ('ann_id1', 'ann_id2'):
            ann_id = _clean_ann_id(conflict.get(key))
            if ann_id is not None:
                conflict_ann_ids.add(ann_id)
        for key in ('image_id1', 'image_id2'):
            image_id = _clean_image_id(conflict.get(key))
            if image_id is not None:
                conflict_image_ids.add(image_id)

    print(f"Looking for {len(conflict_ann_ids)} annotation IDs in {len(conflict_image_ids)} images")

    view = None

    # Strategy 1: sample-level annotation_id field
    try:
        if len(conflict_ann_ids) > 0:
            sample = dataset.first()
            if sample and hasattr(sample, 'annotation_id'):
                view = dataset.match(fo.ViewField("annotation_id").is_in(list(conflict_ann_ids)))
                if len(view) > 0:
                    print(f"‚úì Using sample field 'annotation_id'")
    except Exception as e:
        print(f"  Strategy 1 (annotation_id) failed: {e}")

    # Strategy 2: sample-level image_id field (stored as string)
    if view is None or len(view) == 0:
        try:
            if len(conflict_image_ids) > 0:
                sample = dataset.first()
                if sample and hasattr(sample, 'image_id'):
                    image_id_strings = [str(img_id) for img_id in conflict_image_ids]
                    view = dataset.match(fo.ViewField("image_id").is_in(image_id_strings))
                    if len(view) > 0:
                        print(f"‚úì Using sample field 'image_id' (as string)")
        except Exception as e:
            print(f"  Strategy 2 (image_id string) failed: {e}")

    # Strategy 3: detection IDs inside the 'detections' label field
    if view is None or len(view) == 0:
        try:
            if len(conflict_ann_ids) > 0:
                sample = dataset.first()
                if sample and hasattr(sample, 'detections') and sample.detections:
                    detection_field = 'detections'
                    if hasattr(sample.detections, 'detections') and len(sample.detections.detections) > 0:
                        first_det = sample.detections.detections[0]
                        if hasattr(first_det, 'id'):
                            view = dataset.filter_labels(
                                detection_field,
                                fo.ViewField("id").is_in(list(conflict_ann_ids))
                            )
                            print(f"‚úì Using detection field '{detection_field}' with annotation IDs")
        except Exception as e:
            print(f"  Strategy 3 (detections.id) failed: {e}")

    # Strategy 4: metadata.image_id, tried as string and then as int
    if view is None or len(view) == 0:
        try:
            if len(conflict_image_ids) > 0:
                for converter in (str, int):
                    try:
                        converted_ids = [converter(img_id) for img_id in conflict_image_ids]
                        view = dataset.match(fo.ViewField("metadata.image_id").is_in(converted_ids))
                        if len(view) > 0:
                            print(f"‚úì Using metadata.image_id field (as {converter.__name__})")
                            break
                    except Exception as e:
                        # Best effort: report and try the next representation.
                        print(f"    metadata.image_id as {converter.__name__} failed: {e}")
        except Exception as e:
            print(f"  Strategy 4 (metadata.image_id) failed: {e}")

    # Strategy 5: coco_id
    if view is None or len(view) == 0:
        try:
            if len(conflict_image_ids) > 0:
                view = dataset.match(fo.ViewField("coco_id").is_in(list(conflict_image_ids)))
                if len(view) > 0:
                    print(f"‚úì Using coco_id field")
        except Exception as e:
            print(f"  Strategy 5 (coco_id) failed: {e}")

    # Fallback: nothing matched -- print the schema to guide a manual filter.
    if view is None or len(view) == 0:
        print("\n‚ö†Ô∏è Could not automatically find samples. Dataset schema:")
        sample = dataset.first()
        if sample:
            print(f"  Available sample fields: {list(sample.field_names)}")
            if hasattr(sample, 'metadata'):
                print(f"  Metadata fields: {list(sample.metadata.field_names) if hasattr(sample.metadata, 'field_names') else 'N/A'}")
            if hasattr(sample, 'detections'):
                print(f"  Has detections: Yes")

        print("\n  Showing first few conflict annotation IDs:")
        for ann_id in list(conflict_ann_ids)[:5]:
            print(f"    - {ann_id}")

        print("\n  Please check your dataset structure and modify the filter accordingly.")
        return dataset.limit(0)  # Empty view

    print(f"\n‚úì Found {len(view)} samples with conflicts")

    # One distinct color per processed pair.
    num_pairs = len(selected)
    pair_colors = generate_colors(num_pairs)

    # Map each annotation ID to the info of the pair it belongs to. If an ID
    # appears in several pairs, the last pair wins.
    ann_id_to_conflict = {}
    for pair_idx, conflict in enumerate(selected):
        ann_id1 = _clean_ann_id(conflict.get('ann_id1'))
        ann_id2 = _clean_ann_id(conflict.get('ann_id2'))
        label1 = conflict.get('class_name1', str(conflict.get('label1', 'Unknown')))
        label2 = conflict.get('class_name2', str(conflict.get('label2', 'Unknown')))

        pair_number = pair_idx + 1
        pair_id = f"pair_{pair_number:02d}"
        shared = {
            'similarity': conflict['similarity'],
            'pair_id': pair_id,
            'pair_tag': f"conflict_{pair_id}",
            'pair_number': pair_number,
            'pair_color': pair_colors[pair_idx],
        }

        if ann_id1 is not None:
            ann_id_to_conflict[ann_id1] = {
                **shared,
                'pair_ann_id': ann_id2 or '',
                'pair_label': label2,
                'own_label': label1,
                'role': 'A',  # First sample in pair
            }
        if ann_id2 is not None:
            ann_id_to_conflict[ann_id2] = {
                **shared,
                'pair_ann_id': ann_id1 or '',
                'pair_label': label1,
                'own_label': label2,
                'role': 'B',  # Second sample in pair
            }

    print(f"‚úì Created {num_pairs} unique conflict pairs with distinct colors")

    # Add conflict information to samples/detections
    tagged_count = 0
    pairs_found = set()

    for sample in view:
        modified = False

        # Strategy A: sample-level annotation_id field
        if hasattr(sample, 'annotation_id'):
            ann_id = _clean_ann_id(sample.annotation_id)
            if ann_id is not None and ann_id in ann_id_to_conflict:
                info = ann_id_to_conflict[ann_id]

                sample['conflict_similarity'] = info['similarity']
                sample['conflict_pair_label'] = info['pair_label']
                sample['conflict_own_label'] = info['own_label']
                sample['conflict_pair_ann_id'] = info['pair_ann_id']
                sample['conflict_pair_id'] = info['pair_id']
                sample['conflict_pair_number'] = info['pair_number']
                sample['conflict_role'] = info['role']
                sample['conflict_color'] = info['pair_color']

                # Unique pair tag plus a general 'conflict' tag
                for tag in (info['pair_tag'], 'conflict'):
                    if tag not in sample.tags:
                        sample.tags.append(tag)

                pairs_found.add(info['pair_id'])
                modified = True
                tagged_count += 1

        # Strategy B: detection-level IDs
        elif hasattr(sample, 'detections') and sample.detections:
            for det in sample.detections.detections:
                det_id = str(det.id) if hasattr(det, 'id') else None
                if det_id and det_id in ann_id_to_conflict:
                    info = ann_id_to_conflict[det_id]

                    det['conflict_similarity'] = info['similarity']
                    det['conflict_pair_label'] = info['pair_label']
                    det['conflict_pair_id'] = info['pair_id']
                    det['conflict_pair_number'] = info['pair_number']
                    det['is_conflict'] = True

                    for tag in (info['pair_tag'], 'conflict'):
                        if tag not in det.tags:
                            det.tags.append(tag)

                    pairs_found.add(info['pair_id'])
                    modified = True
                    tagged_count += 1

            # Tag the sample itself if any of its detections is in a conflict
            if modified and 'has_conflict' not in sample.tags:
                sample.tags.append('has_conflict')

        # Persist changes to the dataset
        if modified:
            sample.save()

    print(f"‚úì Tagged {tagged_count} samples/detections with conflict information")
    print(f"‚úì Found {len(pairs_found)} conflict pairs in the view")
    print(f"\n" + "=" * 80)
    print("HOW TO VIEW CONFLICTS IN FIFTYONE APP:")
    print("=" * 80)
    print(f"\n1Ô∏è‚É£  FILTER BY SPECIFIC PAIR:")
    print(f"   - Use tag filter: 'conflict_pair_01', 'conflict_pair_02', etc.")
    print(f"   - Each pair has a unique tag and color!")

    print(f"\n2Ô∏è‚É£  VIEW ALL CONFLICTS:")
    print(f"   - Filter by tag: 'conflict'")
    print(f"   - Sort by 'conflict_pair_number' to group pairs together")

    print(f"\n3Ô∏è‚É£  USEFUL FIELDS:")
    print(f"   - conflict_pair_id: Unique pair identifier (e.g., 'pair_01')")
    print(f"   - conflict_pair_number: Numeric pair ID (1, 2, 3, ...)")
    print(f"   - conflict_own_label: This sample's class")
    print(f"   - conflict_pair_label: The conflicting sample's class")
    print(f"   - conflict_similarity: How similar they are (0-1)")
    print(f"   - conflict_role: 'A' or 'B' (which sample in the pair)")
    print(f"   - conflict_color: Unique color for visualization")

    print(f"\n4Ô∏è‚É£  COLOR CODING:")
    print(f"   - Each pair has a unique color in 'conflict_color' field")
    print(f"   - Use this to visually identify pairs")

    print(f"\nüí° TIP: Sort by 'conflict_pair_number' to see pairs side-by-side!")

    return view

# Example usage (uncomment to run):
# if filter_info and len(filter_info['conflict_details']) > 0:
#     conflict_view = create_conflict_view_in_fiftyone(
#         filter_info['conflict_details'],
#         CONFIG['dataset_name'],
#         max_conflicts=50
#     )
#     
#     # Launch FiftyOne App
#     session = fo.launch_app(conflict_view)
#     print("\nInspect conflicts in FiftyOne App!")
#     print("Look at samples with tag:'conflict' and check 'conflict_pair_label' field")

In [None]:
def create_conflict_pairs_visualization(conflict_view):
    """
    Create a view of conflict pairs sorted by pair number and print a per-pair summary.

    Samples whose ``conflict_pair_number`` is missing or None are left in the
    returned view but excluded from the summary. This happens when a sample
    matched the view's filter but was never tagged, even though the field
    exists on the dataset.

    Args:
        conflict_view: FiftyOne view with tagged conflicts

    Returns:
        Sorted view with pairs grouped together
    """
    # Sort by pair number so pairs appear together
    sorted_view = conflict_view.sort_by("conflict_pair_number")

    print("=" * 80)
    print("CONFLICT PAIRS SUMMARY")
    print("=" * 80)

    # Group tagged samples by pair number
    pairs_info = {}
    for sample in sorted_view:
        pair_num = getattr(sample, 'conflict_pair_number', None)
        if pair_num is None:
            # A None key would make sorted() and the :02d format below fail.
            continue
        info = pairs_info.setdefault(pair_num, {
            'samples': [],
            'labels': set(),
            'similarity': getattr(sample, 'conflict_similarity', None),
            'color': getattr(sample, 'conflict_color', None),
        })
        info['samples'].append(sample.id)
        own_label = getattr(sample, 'conflict_own_label', None)
        if own_label is not None:
            info['labels'].add(str(own_label))

    # Print summary
    for pair_num in sorted(pairs_info):
        info = pairs_info[pair_num]
        labels_str = " ‚ÜîÔ∏è ".join(sorted(info['labels']))
        similarity = info['similarity']
        similarity_str = f"{similarity:.4f}" if similarity is not None else "N/A"
        print(f"\nüìå Pair {int(pair_num):02d} (Tag: conflict_pair_{int(pair_num):02d})")
        print(f"   Classes: {labels_str}")
        print(f"   Similarity: {similarity_str}")
        print(f"   Color: {info['color']}")
        print(f"   Samples: {len(info['samples'])}")

    print("\n" + "=" * 80)
    print(f"Total: {len(pairs_info)} pairs, {len(sorted_view)} samples")
    print("=" * 80)

    return sorted_view


def view_specific_pair(conflict_view, pair_number):
    """
    Print the details of one conflict pair and return a view restricted to it.

    The pair is selected by its tag ``conflict_pair_NN`` (pair number
    zero-padded to two digits). For each matching sample the ID is printed,
    along with the optional conflict fields that are present. A similarity
    of None is skipped instead of raising.

    Args:
        conflict_view: FiftyOne view with tagged conflicts (must support
            ``match_tags``).
        pair_number: Pair number to view (1, 2, 3, ...)

    Returns:
        ``conflict_view.match_tags("conflict_pair_NN")`` (possibly empty).
    """
    pair_tag = f"conflict_pair_{pair_number:02d}"
    pair_view = conflict_view.match_tags(pair_tag)

    print("=" * 80)
    print(f"VIEWING CONFLICT PAIR {pair_number}")
    print("=" * 80)

    # Count once: on FiftyOne views len() may run a database count query,
    # so avoid re-evaluating it for every printed sample.
    num_samples = len(pair_view)
    if num_samples == 0:
        print(f"‚ö†Ô∏è No samples found for pair {pair_number}")
        return pair_view

    # Show pair details
    for i, sample in enumerate(pair_view, 1):
        print(f"\nSample {i}/{num_samples}:")
        print(f"  ID: {sample.id}")
        if hasattr(sample, 'conflict_own_label'):
            print(f"  Label: {sample.conflict_own_label}")
        if hasattr(sample, 'conflict_pair_label'):
            print(f"  Conflicts with: {sample.conflict_pair_label}")
        similarity = getattr(sample, 'conflict_similarity', None)
        if similarity is not None:
            # Guarded: formatting None with :.4f would raise a TypeError
            print(f"  Similarity: {similarity:.4f}")
        if hasattr(sample, 'conflict_role'):
            print(f"  Role in pair: {sample.conflict_role}")

    return pair_view


# Example usage:
# sorted_view = create_conflict_pairs_visualization(conflict_view)
# session = fo.launch_app(sorted_view)
#
# # Or view specific pair:
# pair_1_view = view_specific_pair(conflict_view, 1)
# session = fo.launch_app(pair_1_view)

#### Troubleshooting: If view is empty

If the conflict view shows 0 samples, try this diagnostic cell below to understand your dataset structure.

In [None]:
# DIAGNOSTIC: Understand dataset structure and find correct field for filtering.
# Dumps the first sample's fields / metadata / detection attributes, then tries to
# locate the first conflict's sample via 'annotation_id', 'image_id' and a few
# other candidate fields, reporting which (if any) matched.
if filter_info and len(filter_info['conflict_details']) > 0:
    dataset = fo.load_dataset(CONFIG['dataset_name'])
    first_conflict = filter_info['conflict_details'][0]
    
    print("=" * 80)
    print("DATASET DIAGNOSTIC")
    print("=" * 80)
    
    # Get first sample
    sample = dataset.first()
    
    print(f"\n1. Dataset Info:")
    print(f"   Name: {dataset.name}")
    print(f"   Total samples: {len(dataset)}")
    print(f"   Media type: {dataset.media_type}")
    
    print(f"\n2. Sample Fields:")
    for field_name in sample.field_names:
        field_value = getattr(sample, field_name, None)
        field_type = type(field_value).__name__
        print(f"   - {field_name}: {field_type}")
        
        # Show first few values for small fields
        if field_name in ['id', 'coco_id'] and field_value:
            print(f"     Value: {field_value}")
    
    print(f"\n3. Metadata Fields:")
    if hasattr(sample, 'metadata') and sample.metadata:
        for field_name in sample.metadata.field_names:
            field_value = getattr(sample.metadata, field_name, None)
            print(f"   - metadata.{field_name}: {field_value}")
    
    print(f"\n4. Detections Info:")
    if hasattr(sample, 'detections') and sample.detections:
        dets = sample.detections.detections
        print(f"   Number of detections: {len(dets)}")
        if len(dets) > 0:
            first_det = dets[0]
            print(f"   Detection fields:")
            for attr in dir(first_det):
                if not attr.startswith('_'):
                    try:
                        val = getattr(first_det, attr)
                        if not callable(val):
                            print(f"     - {attr}: {val if attr not in ['label', 'id'] else str(val)[:50]}")
                    except Exception:
                        # Best-effort dump: skip attributes that raise on access,
                        # without also swallowing KeyboardInterrupt/SystemExit.
                        pass
    
    print(f"\n5. Sample Conflict Info:")
    print(f"   First conflict annotation_id1: {first_conflict.get('ann_id1')}")
    print(f"   First conflict annotation_id2: {first_conflict.get('ann_id2')}")
    print(f"   First conflict image_id1: {first_conflict.get('image_id1')}")
    print(f"   First conflict image_id2: {first_conflict.get('image_id2')}")
    
    print(f"\n6. Suggested Fix:")
    print(f"   Based on the above, modify create_conflict_view_in_fiftyone() to use the correct field.")
    print(f"   Common patterns:")
    print(f"     - If detections have 'id': Filter by detections.id")
    print(f"     - If samples have 'coco_id': Filter by coco_id matching image_id")
    print(f"     - If metadata has 'image_id': Filter by metadata.image_id")
    
    # Try to find a matching sample manually
    print(f"\n7. Manual Search Test:")
    test_image_id = first_conflict.get('image_id1')
    test_ann_id = first_conflict.get('ann_id1')
    
    found = False
    
    # Test 1: Search by annotation_id
    if test_ann_id and not found:
        print(f"   Searching for annotation_id={test_ann_id}...")
        try:
            view = dataset.match(fo.ViewField("annotation_id") == str(test_ann_id))
            if len(view) > 0:
                print(f"   ‚úì FOUND using field 'annotation_id': {len(view)} sample(s)")
                sample = view.first()
                print(f"     Sample ID: {sample.id}")
                print(f"     Annotation ID: {sample.annotation_id}")
                if hasattr(sample, 'image_id'):
                    print(f"     Image ID: {sample.image_id}")
                found = True
            else:
                print(f"   ‚úó Not found using field 'annotation_id'")
        except Exception as e:
            print(f"   ‚úó Error with field 'annotation_id': {e}")
    
    # Test 2: Search by image_id (as string)
    if test_image_id and not found:
        print(f"   Searching for image_id={test_image_id} (as string)...")
        try:
            view = dataset.match(fo.ViewField("image_id") == str(test_image_id))
            if len(view) > 0:
                print(f"   ‚úì FOUND using field 'image_id' (string): {len(view)} sample(s)")
                sample = view.first()
                print(f"     Sample ID: {sample.id}")
                if hasattr(sample, 'annotation_id'):
                    print(f"     Annotation ID: {sample.annotation_id}")
                if hasattr(sample, 'image_id'):
                    print(f"     Image ID: {sample.image_id}")
                found = True
            else:
                print(f"   ‚úó Not found using field 'image_id' (string)")
        except Exception as e:
            print(f"   ‚úó Error with field 'image_id' (string): {e}")
    
    # Test 3: Try other fields (compared as int; a non-numeric id is reported as an error)
    if test_image_id and not found:
        print(f"   Searching for image_id={test_image_id} (trying other fields)...")
        
        for field in ['coco_id', 'metadata.image_id', 'id']:
            try:
                view = dataset.match(fo.ViewField(field) == int(test_image_id))
                if len(view) > 0:
                    print(f"   ‚úì FOUND using field '{field}': {len(view)} sample(s)")
                    sample = view.first()
                    print(f"     Sample ID: {sample.id}")
                    if hasattr(sample, 'detections'):
                        print(f"     Detections: {len(sample.detections.detections) if sample.detections else 0}")
                    found = True
                    break
                else:
                    print(f"   ‚úó Not found using field '{field}'")
            except Exception as e:
                print(f"   ‚úó Error with field '{field}': {e}")
    
    if not found:
        # Nothing matched: the view-building function will NOT work as-is.
        print(f"\n   ‚ö†Ô∏è Could not find any samples with the tested fields. "
              f"Inspect the field listing above and adapt create_conflict_view_in_fiftyone() accordingly.")
                
else:
    print("No conflict info available. Run the conflict filtering first.")

#### 🎨 NEW: Visualize Conflicts with Color-Coded Pairs

Now with unique colored tags for each conflict pair!

In [None]:
# Tag the top conflicts in FiftyOne (one unique tag + colour per pair), show a
# per-pair summary, and open the FiftyOne App on the pair-sorted view.
if not (filter_info and len(filter_info['conflict_details']) > 0):
    print("No conflict info available. Run conflict filtering first.")
else:
    # Step 1: create the tagged conflict view (top 50 conflicts).
    conflict_view = create_conflict_view_in_fiftyone(
        filter_info['conflict_details'],
        CONFIG['dataset_name'],
        max_conflicts=50,  # Show top 50 conflicts
    )

    if len(conflict_view) == 0:
        print("\n‚ö†Ô∏è No conflicts found in the view.")
    else:
        print(f"\n‚úÖ Successfully created view with {len(conflict_view)} samples")

        # Step 2: order the samples so both members of a pair sit together.
        sorted_view = create_conflict_pairs_visualization(conflict_view)

        # Step 3: open the App on the ordered view.
        print("\nüöÄ Launching FiftyOne App...")
        session = fo.launch_app(sorted_view)

        # Usage hints, printed one line at a time.
        app_tips = (
            "\n" + "=" * 80,
            "FIFTYONE APP TIPS:",
            "=" * 80,
            "\nüîç Filtering:",
            "   ‚Ä¢ All conflicts: tag = 'conflict'",
            "   ‚Ä¢ Specific pair: tag = 'conflict_pair_01' (or 02, 03, etc.)",
            "\nüìä Sorting:",
            "   ‚Ä¢ Sort by 'conflict_pair_number' to see pairs together",
            "   ‚Ä¢ Sort by 'conflict_similarity' to see most similar first",
            "\nüé® Color Coding:",
            "   ‚Ä¢ Check 'conflict_color' field for each pair's unique color",
            "   ‚Ä¢ Same color = same conflict pair",
            "\nüìù Fields to Check:",
            "   ‚Ä¢ conflict_own_label: This image's class",
            "   ‚Ä¢ conflict_pair_label: What it conflicts with",
            "   ‚Ä¢ conflict_similarity: How similar (higher = more similar)",
            "=" * 80,
        )
        for tip_line in app_tips:
            print(tip_line)

#### 🔎 View Specific Conflict Pairs

Use these commands to focus on specific pairs:

In [None]:
# Example 1: View only conflict pair #1
# pair_1_view = view_specific_pair(conflict_view, 1)
# session = fo.launch_app(pair_1_view)

# Example 2: View conflict pair #2
# pair_2_view = view_specific_pair(conflict_view, 2)
# session = fo.launch_app(pair_2_view)

# Example 3: Iterate through all pairs
# for pair_num in range(1, 11):  # View first 10 pairs
#     print(f"\n{'='*80}")
#     pair_view = view_specific_pair(conflict_view, pair_num)
#     if len(pair_view) > 0:
#         # Uncomment to launch app for each pair:
#         # session = fo.launch_app(pair_view)
#         # input("Press Enter to continue to next pair...")
#         pass

# Example 4: Get all samples from pairs with high similarity (>0.95)
# high_sim_view = conflict_view.match(fo.ViewField("conflict_similarity") > 0.95)
# print(f"Found {len(high_sim_view)} samples with similarity > 0.95")
# session = fo.launch_app(high_sim_view)

# Example 5: Group by conflict between specific classes
# If you want to see all conflicts between two specific classes:
# class1 = "Thunnus albacares"
# class2 = "Thunnus obesus"
# specific_conflict = conflict_view.match(
#     (fo.ViewField("conflict_own_label") == class1) & 
#     (fo.ViewField("conflict_pair_label") == class2) |
#     (fo.ViewField("conflict_own_label") == class2) & 
#     (fo.ViewField("conflict_pair_label") == class1)
# )
# print(f"Conflicts between {class1} and {class2}: {len(specific_conflict)}")
# session = fo.launch_app(specific_conflict)

---

## 🎨 Summary: Color-Coded Conflict Visualization

### What's New:

1. **🏷️ Unique Tags per Pair**
   - Each conflict pair gets a unique tag: `conflict_pair_01`, `conflict_pair_02`, etc.
   - Easy filtering in FiftyOne App

2. **🎨 Color Coding**
   - Each pair has a unique color (stored in `conflict_color` field)
   - Generated using HSV color space for maximum visual distinction
   - 50 unique colors for 50 pairs

3. **📊 Rich Metadata**
   - `conflict_pair_id`: Unique identifier (e.g., "pair_01")
   - `conflict_pair_number`: Numeric ID (1, 2, 3, ...)
   - `conflict_own_label`: This sample's class
   - `conflict_pair_label`: Conflicting sample's class
   - `conflict_similarity`: Similarity score (0-1)
   - `conflict_role`: 'A' or 'B' (position in pair)
   - `conflict_color`: Hex color code for visualization

4. **🔧 Helper Functions**
   - `create_conflict_pairs_visualization()`: Shows summary of all pairs
   - `view_specific_pair()`: Focus on a single pair

### Benefits:

✅ **Easy Identification**: Instantly see which images belong to the same conflict pair  
✅ **Organized View**: Sort by pair number to see pairs side-by-side  
✅ **Quick Filtering**: Use unique tags to focus on specific pairs  
✅ **Visual Distinction**: Each pair has its own color  
✅ **Detailed Analysis**: Rich metadata for each conflict

### Workflow:

```
1. Run conflict filtering → Identifies conflicts
2. Create conflict view → Tags with colors
3. Launch FiftyOne App → Visual inspection
4. Filter by pair tag → Focus on specific conflicts
5. Analyze and fix → Update dataset
```

#### Step-by-Step: Using FiftyOne Visualization

**Recommended approach:**

1. **First**, run the diagnostic cell above to understand your dataset structure
2. **Then**, run the visualization code below
3. **If view is empty**, check the diagnostic output and modify the function accordingly

In [None]:
# Tag the top 50 conflicts and open the FiftyOne App on the (unsorted) view.
# Guards against missing conflict info and against an empty view so the App
# is only launched when there is something to inspect.
if filter_info and len(filter_info['conflict_details']) > 0:
    conflict_view = create_conflict_view_in_fiftyone(
        filter_info['conflict_details'],
        CONFIG['dataset_name'],
        max_conflicts=50
    )

    if len(conflict_view) > 0:
        # Launch FiftyOne App
        session = fo.launch_app(conflict_view)
        print("\nInspect conflicts in FiftyOne App!")
        print("Look at samples with tag:'conflict' and check 'conflict_pair_label' field")
    else:
        print("\nWARNING: No conflicts found in the view. "
              "Run the diagnostic cell above to check which dataset field to match on.")
else:
    print("No conflict info available. Run conflict filtering first.")

## 8b. Weighted Fusion: Grid Search for Optimal Weights

This section finds the optimal weights for combining ArcFace and kNN predictions:
- **Weighted Fusion**: `final_score = α * arcface_score + β * knn_score`
- Grid search over different α values (β = 1 − α)
- Evaluate on a random subset (up to 1000 samples) of the validation set
- Compare with other reranking methods (RRF, hybrid)

In [None]:
# Load validation dataset
print("Loading validation dataset...")
val_dataset = fo.load_dataset('classification_v0.10_val')
print(f"Validation samples: {len(val_dataset)}")

# Prepare data for grid search (random subset for faster iteration).
# take() without a seed draws a different random subset on every run, which
# would make the selected weights non-reproducible; fix the seed instead.
VAL_SUBSET_SEED = 51
sample_size = min(1000, len(val_dataset))  # Use subset for faster tuning
val_samples = val_dataset.take(sample_size, seed=VAL_SUBSET_SEED)

print(f"\nUsing {sample_size} validation samples for weight optimization")
print("This will take a few minutes...")

In [None]:
def evaluate_weighted_fusion(alpha, beta, val_samples, classifier):
    """
    Evaluate top-1 / top-5 accuracy of ``classifier`` in 'weighted_fusion' mode.

    The classifier's ``arcface_weight``, ``knn_weight`` and ``rerank_mode`` are
    overridden for the duration of the call. They are always restored
    afterwards, even if evaluation is interrupted (e.g. KeyboardInterrupt).

    Args:
        alpha: Weight for ArcFace scores
        beta: Weight for kNN scores
        val_samples: Iterable of validation samples; each must expose
            ``filepath`` and ``polyline.label`` (ground-truth class name).
        classifier: EmbeddingClassifier instance; calling it on an image array
            must return a ranked sequence of results with a ``name`` attribute.

    Returns:
        (top1_accuracy, top5_accuracy): fractions over the samples that were
        processed successfully. Samples that raise are reported and skipped,
        and do not count toward the denominator. Both are 0 if no sample
        could be processed.
    """
    from PIL import Image
    import numpy as np

    # Remember the current configuration so it can be restored in `finally`
    original_alpha = classifier.arcface_weight
    original_beta = classifier.knn_weight
    original_mode = classifier.rerank_mode

    correct_top1 = 0
    correct_top5 = 0
    total = 0

    try:
        classifier.arcface_weight = alpha
        classifier.knn_weight = beta
        classifier.rerank_mode = 'weighted_fusion'

        for sample in val_samples:
            try:
                # Context manager closes the file handle; np.array() forces the
                # pixel data to be loaded while the file is still open.
                with Image.open(sample.filepath) as image:
                    image_array = np.array(image)

                # Get predictions
                results = classifier(image_array)

                # Ground truth
                gt_label = sample.polyline.label

                # Top-1
                if results[0].name == gt_label:
                    correct_top1 += 1

                # Top-5
                top5_names = [r.name for r in results[:5]]
                if gt_label in top5_names:
                    correct_top5 += 1

                total += 1

            except Exception as e:
                print(f"Error processing sample: {e}")
                continue
    finally:
        # Restore original weights
        classifier.arcface_weight = original_alpha
        classifier.knn_weight = original_beta
        classifier.rerank_mode = original_mode

    top1_acc = correct_top1 / total if total > 0 else 0
    top5_acc = correct_top5 / total if total > 0 else 0

    return top1_acc, top5_acc

print("‚úÖ Evaluation function defined")

In [None]:
# Build the EmbeddingClassifier used by the grid search and method comparison:
# BEiT v2 backbone from the training checkpoint + the top-100 kNN embedding database.
from train_scripts.classification.interpreter_classifier_lightning import EmbeddingClassifier

database_path = Path(CONFIG['output_dir']) / 'embedding_database_beitv2_top100.pt'

model_config = dict(
    checkpoint_path=CONFIG['checkpoint_path'],
    backbone_model_name='beitv2_base_patch16_224.in1k_ft_in22k_in1k',
    embedding_dim=512,
    num_classes=len(train_class_to_id),
    arcface_s=64.0,
    arcface_m=0.2,
    pooling_type='attention',
    device=CONFIG['device'],
)

config = dict(
    log_level='CRITICAL',  # Reduce verbosity
    dataset={'path': database_path},
    model=model_config,
    use_knn=True,
    rerank_mode='weighted_fusion',
)

print("Initializing classifier...")
classifier = EmbeddingClassifier(config)
print("‚úÖ Classifier initialized")

In [None]:
# Grid Search for optimal weights: sweep the ArcFace weight alpha over [0, 1]
# with the kNN weight beta = 1 - alpha, evaluating each pair on val_samples.
print("=" * 60)
print("GRID SEARCH: Finding Optimal Weights")
print("=" * 60)

# Define search space. Rounded so reported/saved weights are exact two-decimal
# values (raw np.linspace yields e.g. 0.15000000000000002).
alphas = np.round(np.linspace(0, 1, 21), 2)  # 0.0, 0.05, 0.1, ..., 1.0
results_grid = []

print(f"\nTesting {len(alphas)} weight combinations...")
print(f"{'Alpha':>6} {'Beta':>6} {'Top-1':>8} {'Top-5':>8}")
print("-" * 35)

for alpha in tqdm(alphas, desc="Grid Search"):
    alpha = float(alpha)
    beta = round(1.0 - alpha, 2)

    # Evaluate
    top1_acc, top5_acc = evaluate_weighted_fusion(alpha, beta, val_samples, classifier)

    results_grid.append({
        'alpha': alpha,
        'beta': beta,
        'top1_accuracy': top1_acc,
        'top5_accuracy': top5_acc
    })

    # Print progress
    print(f"{alpha:6.2f} {beta:6.2f} {top1_acc:8.4f} {top5_acc:8.4f}")

# Select the best combination from the evaluated grid: highest top-1, ties
# broken by top-5; on exact ties max() keeps the first (smallest alpha) entry.
# Selecting from results_grid guarantees best_alpha/best_beta are always a
# consistent, actually-evaluated pair (even if every accuracy is 0).
best_result = max(results_grid, key=lambda r: (r['top1_accuracy'], r['top5_accuracy']))
best_alpha = best_result['alpha']
best_beta = best_result['beta']
best_top1 = best_result['top1_accuracy']
best_top5 = best_result['top5_accuracy']

print("\n" + "=" * 60)
print("BEST WEIGHTS FOUND:")
print("=" * 60)
print(f"  Œ± (ArcFace weight): {best_alpha:.4f}")
print(f"  Œ≤ (kNN weight):     {best_beta:.4f}")
print(f"  Top-1 Accuracy:     {best_top1:.4f} ({best_top1*100:.2f}%)")
print(f"  Top-5 Accuracy:     {best_top5:.4f} ({best_top5*100:.2f}%)")
print("=" * 60)

In [None]:
# Plot top-1 and top-5 accuracy against the ArcFace weight alpha, marking the
# selected alpha, and save the figure next to the other experiment outputs.
df_grid = pd.DataFrame(results_grid)
grid_plot_path = Path(CONFIG['output_dir']) / 'weighted_fusion_grid_search.png'

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (axis, column, line format, best value, y label, title) for each panel
panels = (
    (axes[0], 'top1_accuracy', 'b-o', best_top1, 'Top-1 Accuracy',
     'Weighted Fusion: Top-1 Accuracy vs Weight'),
    (axes[1], 'top5_accuracy', 'g-s', best_top5, 'Top-5 Accuracy',
     'Weighted Fusion: Top-5 Accuracy vs Weight'),
)
for panel_ax, column, line_fmt, best_value, y_label, panel_title in panels:
    panel_ax.plot(df_grid['alpha'], df_grid[column], line_fmt, linewidth=2, markersize=6)
    panel_ax.axvline(best_alpha, color='r', linestyle='--', label=f'Best Œ±={best_alpha:.2f}')
    panel_ax.axhline(best_value, color='r', linestyle=':', alpha=0.3)
    panel_ax.set_xlabel('Œ± (ArcFace Weight)', fontsize=12)
    panel_ax.set_ylabel(y_label, fontsize=12)
    panel_ax.set_title(panel_title, fontsize=14, fontweight='bold')
    panel_ax.grid(True, alpha=0.3)
    panel_ax.legend()

plt.tight_layout()
plt.savefig(grid_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print("\n‚úÖ Visualization saved to:", grid_plot_path)

### Compare Reranking Methods

Now compare weighted fusion with optimal weights vs other methods (RRF, Hybrid)

In [None]:
# Compare different reranking methods on the same validation subset.
# The classifier's original weights/mode are captured up front. Methods without
# explicit weights (RRF, Hybrid) run with those original weights rather than
# whatever the previous method left behind, and everything is restored at the
# end (also on interruption), so later cells see the classifier unchanged.
from PIL import Image

print("=" * 60)
print("COMPARISON: Different Reranking Methods")
print("=" * 60)

methods_to_test = [
    ('Weighted Fusion (Optimal)', 'weighted_fusion', {'alpha': best_alpha, 'beta': best_beta}),
    ('Weighted Fusion (Equal)', 'weighted_fusion', {'alpha': 0.5, 'beta': 0.5}),
    ('Weighted Fusion (ArcFace Priority)', 'weighted_fusion', {'alpha': 0.7, 'beta': 0.3}),
    ('Weighted Fusion (kNN Priority)', 'weighted_fusion', {'alpha': 0.3, 'beta': 0.7}),
    ('Reciprocal Rank Fusion', 'rrf', {}),
    ('Hybrid (Original)', 'hybrid', {}),
]

original_alpha = classifier.arcface_weight
original_beta = classifier.knn_weight
original_mode = classifier.rerank_mode

comparison_results = []

try:
    for method_name, mode, params in methods_to_test:
        print(f"\nTesting: {method_name}")

        # Configure classifier
        classifier.rerank_mode = mode
        classifier.arcface_weight = params.get('alpha', original_alpha)
        classifier.knn_weight = params.get('beta', original_beta)

        # Evaluate
        correct_top1 = 0
        correct_top5 = 0
        total = 0
        skipped = 0

        for sample in tqdm(val_samples, desc=method_name, leave=False):
            try:
                # Context manager closes the file; np.array() loads pixels while open
                with Image.open(sample.filepath) as image:
                    image_array = np.array(image)
                results = classifier(image_array)

                gt_label = sample.polyline.label

                # Top-1
                if results[0].name == gt_label:
                    correct_top1 += 1

                # Top-5
                top5_names = [r.name for r in results[:5]]
                if gt_label in top5_names:
                    correct_top5 += 1

                total += 1
            except Exception:
                # Count instead of silently dropping, so accuracy denominators are visible
                skipped += 1

        top1_acc = correct_top1 / total if total > 0 else 0
        top5_acc = correct_top5 / total if total > 0 else 0

        comparison_results.append({
            'method': method_name,
            'mode': mode,
            'top1_accuracy': top1_acc,
            'top5_accuracy': top5_acc,
            'params': str(params),
            'skipped_samples': skipped,
        })

        print(f"  Top-1: {top1_acc:.4f} ({top1_acc*100:.2f}%)")
        print(f"  Top-5: {top5_acc:.4f} ({top5_acc*100:.2f}%)")
        if skipped:
            print(f"  Skipped {skipped} sample(s) due to errors")
finally:
    # Leave the classifier exactly as it was before the comparison
    classifier.arcface_weight = original_alpha
    classifier.knn_weight = original_beta
    classifier.rerank_mode = original_mode

# Create comparison DataFrame
df_comparison = pd.DataFrame(comparison_results)
df_comparison = df_comparison.sort_values('top1_accuracy', ascending=False)

print("\n" + "=" * 60)
print("RANKING BY TOP-1 ACCURACY:")
print("=" * 60)
print(df_comparison[['method', 'top1_accuracy', 'top5_accuracy']].to_string(index=False))
print("=" * 60)

In [None]:
# Grouped bar chart of top-1 / top-5 accuracy per reranking method, in the row
# order of df_comparison, with each bar annotated by its value.
fig, ax = plt.subplots(figsize=(12, 6))

method_positions = np.arange(len(df_comparison))
bar_width = 0.35

# (x offset, column, legend label, colour) for each bar series
bar_series = (
    (-bar_width / 2, 'top1_accuracy', 'Top-1 Accuracy', 'steelblue'),
    (bar_width / 2, 'top5_accuracy', 'Top-5 Accuracy', 'lightcoral'),
)
bar_containers = [
    ax.bar(method_positions + offset, df_comparison[column], bar_width,
           label=series_label, alpha=0.8, color=series_color)
    for offset, column, series_label, series_color in bar_series
]

# Add value labels on bars
for container in bar_containers:
    for bar in container:
        bar_height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., bar_height, f'{bar_height:.3f}',
                ha='center', va='bottom', fontsize=9)

ax.set_xlabel('Reranking Method', fontsize=12, fontweight='bold')
ax.set_ylabel('Accuracy', fontsize=12, fontweight='bold')
ax.set_title('Comparison of Reranking Methods', fontsize=14, fontweight='bold')
ax.set_xticks(method_positions)
ax.set_xticklabels(df_comparison['method'], rotation=45, ha='right')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(Path(CONFIG['output_dir']) / 'reranking_methods_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n‚úÖ Comparison visualization saved")

### 📊 Summary & Recommendations

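Based on the grid search results. Note that the reported accuracy is measured on the same validation subset that was used to select the weights, so it is an optimistic estimate. Confirm it on a held-out split before relying on it.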

In [None]:
# Persist the selected fusion weights, their accuracy on the tuning subset and
# the per-method comparison as JSON, then print a ready-to-paste config snippet.
config_path = Path(CONFIG['output_dir']) / 'optimal_reranking_config.json'

optimal_weights = {
    'arcface_weight': float(best_alpha),
    'knn_weight': float(best_beta),
}
tuning_performance = {
    'top1_accuracy': float(best_top1),
    'top5_accuracy': float(best_top5),
    'validation_samples': len(val_samples),
}
optimal_config = {
    'reranking': {
        'method': 'weighted_fusion',
        'optimal_weights': optimal_weights,
        'performance': tuning_performance,
    },
    'comparison': comparison_results,
}

with open(config_path, 'w') as config_file:
    json.dump(optimal_config, config_file, indent=2)

summary_lines = [
    "=" * 60,
    "‚úÖ OPTIMAL CONFIGURATION SAVED",
    "=" * 60,
    f"Location: {config_path}",
    "\nRecommended config for production:",
    "```python",
    "config = {",
    "    'rerank_mode': 'weighted_fusion',",
    f"    'arcface_weight': {best_alpha:.4f},",
    f"    'knn_weight': {best_beta:.4f},",
    "}",
    "```",
    "\nExpected performance:",
    f"  Top-1 Accuracy: {best_top1*100:.2f}%",
    f"  Top-5 Accuracy: {best_top5*100:.2f}%",
    "=" * 60,
]
print("\n".join(summary_lines))

---

## 🔬 Advanced Tips & Best Practices

### Interpreting Conflict Results

**High conflict rate (>5%):**
- May indicate systematic labeling issues
- Consider increasing similarity threshold (0.97-0.99)
- Manually review top conflicts in FiftyOne
- Check if certain annotators/sources have more errors

**Low conflict rate (<0.5%):**
- Good data quality! 
- Can try lower threshold (0.90-0.93) for more aggressive cleaning
- May still have subtle labeling errors

### Common Conflict Patterns

1. **Near-Identical Species**
   - Example: Juvenile vs Adult of same species
   - Solution: May be valid, check biological accuracy
   
2. **Similar Looking Species**
   - Example: Two tuna species with similar appearance
   - Solution: True labeling challenge, may need expert review
   
3. **Data Collection Errors**
   - Example: Same image uploaded multiple times with different labels
   - Solution: These should be removed (true duplicates)
   
4. **Annotation Mistakes**
   - Example: Misclicked label during annotation
   - Solution: Fix in dataset, re-export

### Next Steps After Conflict Filtering

1. **Review conflict_report.csv**
   - Sort by similarity (highest first)
   - Manually check top 20-50 conflicts
   - Identify patterns

2. **Fix Dataset Issues**
   - For confirmed errors, correct labels in your dataset
   - Remove true duplicates
   - Re-export and re-run training

3. **Monitor Impact**
   - Compare model performance before/after filtering
   - Check if problematic class pairs improve
   - Validate on held-out test set

4. **Iterate**
   - Adjust threshold based on results
   - Re-run conflict filtering after dataset fixes
   - Track improvement over time