# HoneyBee Workshop Part 5: Retrieval Evaluation

## Overview
In this workshop, you'll learn how to:
1. Implement similarity-based retrieval using embeddings
2. Evaluate retrieval performance with Precision@k metrics
3. Analyze retrieval failures and confusion patterns
4. Compare retrieval across different modalities

**Duration**: 30 minutes

**Prerequisites**: 
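- Completed Parts 1-4, or access to pre-computed embeddings (if the embeddings directory is not found, the notebook falls back to randomly generated mock data)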
- Understanding of information retrieval concepts

## 1. Setup and Imports

In [None]:
import os
import sys
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Similarity and evaluation imports
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from sklearn.metrics import adjusted_mutual_info_score
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors

# Fix the global RNG seed so mock data and example selections are
# reproducible, and apply a consistent plotting style for every figure.
np.random.seed(42)
plt.style.use('seaborn-v0_8-darkgrid')

print("Libraries loaded successfully!")

## 2. Load Embeddings and Labels

In [None]:
# Load embeddings (using clinical as example)
local_path = Path("/mnt/f/Projects/HoneyBee/results/shared_data/embeddings")
clinical_emb_path = local_path / "clinical_embeddings_tcga.pkl"
labels_path = local_path.parent / "patient_cancer_types.csv"

# Use the pre-computed files only when BOTH are present; otherwise fall back to
# mock data so that `embeddings_df` and `labels_df` are always defined.
if clinical_emb_path.exists() and labels_path.exists():
    print("Loading from local pre-computed embeddings...")
    # NOTE: read_pickle can execute arbitrary code -- only load trusted files.
    embeddings_df = pd.read_pickle(clinical_emb_path)
    labels_df = pd.read_csv(labels_path)
else:
    # Create mock data
    print("Creating mock data for demonstration...")
    n_samples = 500
    n_features = 768

    embeddings_df = pd.DataFrame(
        np.random.randn(n_samples, n_features),
        index=[f"TCGA-{i:04d}" for i in range(n_samples)]
    )

    cancer_types = ['BRCA', 'LUAD', 'KIRC', 'THCA', 'PRAD']
    labels_df = pd.DataFrame({
        'patient_id': embeddings_df.index,
        'cancer_type': np.random.choice(cancer_types, n_samples)
    })


def align_embeddings_with_labels(embeddings_df, labels_df):
    """
    Return embeddings, labels and patient IDs aligned row-by-row.

    Parameters
    ----------
    embeddings_df : pd.DataFrame
        One embedding per row, indexed by patient ID.
    labels_df : pd.DataFrame
        Must contain 'patient_id' and 'cancer_type' columns.

    Returns
    -------
    embeddings : np.ndarray
        Embedding rows for the shared patients.
    labels : np.ndarray
        Cancer type for the same patients, in the same order.
    patient_ids : list
        The shared patient IDs, sorted; entry i describes row i of both arrays.

    Raises
    ------
    ValueError
        If no patient ID appears in both inputs.
    """
    # Keep the first occurrence of any duplicated patient so each ID maps to
    # exactly one embedding row and one label.
    emb = embeddings_df[~embeddings_df.index.duplicated(keep='first')]
    label_series = (
        labels_df.drop_duplicates(subset='patient_id', keep='first')
        .set_index('patient_id')['cancer_type']
    )
    # Sorted for run-to-run determinism (set iteration order is not stable).
    common = sorted(set(emb.index) & set(label_series.index))
    if not common:
        raise ValueError("No patient IDs are shared between embeddings and labels")
    # Look BOTH tables up by the same ordered ID list so row i of each array
    # refers to the same patient (filtering labels_df with isin() would keep
    # its own row order instead).
    return emb.loc[common].values, label_series.loc[common].values, common


# Align data
embeddings, labels, patient_ids = align_embeddings_with_labels(embeddings_df, labels_df)

print(f"Dataset shape: {embeddings.shape}")
print(f"Unique cancer types: {np.unique(labels)}")

## 3. Implement Retrieval System

In [None]:
class EmbeddingRetriever:
    """
    Similarity-based retrieval system for embeddings.

    Builds a scikit-learn ``NearestNeighbors`` index over every row of
    ``embeddings`` so that any row can be used as a query against the
    others (leave-one-out retrieval).

    Parameters
    ----------
    embeddings : array-like
        2-D array with one embedding per row.
    labels : array-like
        One label per embedding row. Stored as a NumPy array so it can be
        indexed with the integer arrays returned by ``retrieve``.
    metric : str, default 'cosine'
        Distance metric passed to ``NearestNeighbors``.

    Raises
    ------
    ValueError
        If ``embeddings`` and ``labels`` have different lengths.
    """
    def __init__(self, embeddings, labels, metric='cosine'):
        self.embeddings = np.asarray(embeddings)
        self.labels = np.asarray(labels)
        self.metric = metric

        if len(self.embeddings) != len(self.labels):
            raise ValueError(
                f"embeddings ({len(self.embeddings)}) and labels "
                f"({len(self.labels)}) must have the same length"
            )

        # Build nearest neighbors index. This n_neighbors is only the default
        # for kneighbors(); retrieve() always passes an explicit value.
        self.nn = NearestNeighbors(
            n_neighbors=min(50, len(self.embeddings)),
            metric=metric,
            n_jobs=-1
        )
        self.nn.fit(self.embeddings)

    def retrieve(self, query_idx, k=10):
        """
        Retrieve the k items most similar to the item at ``query_idx``.

        The query itself is never part of the result. If fewer than k other
        items exist, all of them are returned.

        Returns
        -------
        (indices, distances) : tuple of np.ndarray
            Neighbour row indices and their distances, nearest first.
        """
        n_samples = len(self.embeddings)
        # Normalise negative indices (and reject out-of-range ones) so the
        # self-exclusion comparison below matches kneighbors' output indices.
        self_idx = range(n_samples)[query_idx]
        query = self.embeddings[self_idx].reshape(1, -1)

        # One extra neighbour to make room for the query, capped at the index
        # size because kneighbors raises when asked for more than n_samples.
        n_request = min(k + 1, n_samples)
        distances, indices = self.nn.kneighbors(query, n_neighbors=n_request)
        indices, distances = indices[0], distances[0]

        # Remove self by identity, not by position: when several embeddings
        # are identical the query is not guaranteed to be ranked first.
        keep = indices != self_idx
        return indices[keep][:k], distances[keep][:k]

    def batch_retrieve(self, query_indices, k=10):
        """
        Retrieve neighbours for several queries.

        Returns a list with one dict per query holding the query index and
        label, the retrieved indices and labels, and their distances.
        """
        results = []
        for idx in query_indices:
            retrieved_indices, distances = self.retrieve(idx, k)
            results.append({
                'query_idx': idx,
                'query_label': self.labels[idx],
                'retrieved_indices': retrieved_indices,
                'retrieved_labels': self.labels[retrieved_indices],
                'distances': distances
            })
        return results

# Build the default cosine-distance retriever used by the following sections
retriever = EmbeddingRetriever(embeddings=embeddings, labels=labels, metric='cosine')
print("Retrieval system initialized!")

## 4. Calculate Precision@k Metrics

In [None]:
def calculate_precision_at_k(retriever, k_values=(1, 5, 10, 20, 50)):
    """
    Calculate leave-one-out Precision@k for several values of k.

    Every item in ``retriever.labels`` is used once as a query. A retrieved
    item counts as relevant when it shares the query's label.

    Parameters
    ----------
    retriever : EmbeddingRetriever
        Object exposing ``labels`` (NumPy array) and ``retrieve(idx, k)``.
    k_values : iterable of int
        Cut-offs to evaluate (duplicates are evaluated once).

    Returns
    -------
    mean_precisions : dict
        k -> mean Precision@k over all queries, ignoring NaN entries
        (NaN if no query had k retrievable neighbours).
    precisions : dict
        k -> list with one Precision@k per query, in query-index order.
        Entries are NaN for queries with fewer than k retrievable neighbours,
        so ``precisions[k][i]`` always refers to query ``i``.
    """
    precisions = {k: [] for k in k_values}
    if not precisions:
        return {}, precisions

    labels = retriever.labels
    # Retrieve once at the largest cut-off, then slice for the smaller ones.
    max_k = max(precisions)

    print("Evaluating retrieval performance...")
    for query_idx in tqdm(range(len(labels))):
        retrieved_indices, _ = retriever.retrieve(query_idx, k=max_k)
        # Running count of relevant items at each rank.
        hits = np.cumsum(labels[retrieved_indices] == labels[query_idx])

        for k in precisions:
            if 0 < k <= len(hits):
                precisions[k].append(hits[k - 1] / k)
            else:
                # Record a placeholder so the list stays aligned with queries.
                precisions[k].append(np.nan)

    mean_precisions = {
        k: float(np.nanmean(p)) if np.any(~np.isnan(p)) else float('nan')
        for k, p in precisions.items()
    }
    return mean_precisions, precisions

# Evaluate the default retriever at a handful of cut-offs
k_values = [1, 5, 10, 20, 30]
mean_precisions, all_precisions = calculate_precision_at_k(retriever, k_values)

# Report the mean Precision@k for each cut-off
print("\nPrecision@k Results:")
for cutoff, score in mean_precisions.items():
    print(f"Precision@{cutoff}: {score:.4f}")

## 5. Visualize Retrieval Performance

In [None]:
# Plot Precision@k curve (mean precision against cut-off k)
k_vals = sorted(mean_precisions)
precisions = [mean_precisions[k] for k in k_vals]

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(k_vals, precisions, 'o-', linewidth=2, markersize=8)
ax.set_xlabel('k', fontsize=12)
ax.set_ylabel('Precision@k', fontsize=12)
ax.set_title('Retrieval Performance: Precision@k', fontsize=14)
ax.grid(True, alpha=0.3)
ax.set_ylim(0, 1)

# Print each point's value just above its marker
for k_point, p_point in zip(k_vals, precisions):
    ax.annotate(f'{p_point:.3f}', (k_point, p_point), textcoords="offset points",
                xytext=(0, 10), ha='center')

fig.tight_layout()
plt.show()

# Box plot by cancer type
k = 10  # Focus on Precision@10
precision_by_type = {}

for cancer_type in np.unique(labels):
    type_indices = np.where(labels == cancer_type)[0]
    precision_by_type[cancer_type] = [all_precisions[k][i] for i in type_indices]

# Cancer types have different patient counts, so the per-type lists are ragged
# and pd.DataFrame(dict_of_lists) would raise ValueError. Wrapping each list in
# a Series pads the shorter columns with NaN, which DataFrame.boxplot drops.
df_box = pd.DataFrame(
    {ctype: pd.Series(vals, dtype=float) for ctype, vals in precision_by_type.items()}
)

# Draw on an explicit axes; boxplot(figsize=...) does not resize an
# already-created figure.
fig, ax = plt.subplots(figsize=(12, 6))
df_box.boxplot(ax=ax)
ax.set_ylabel(f'Precision@{k}')
ax.set_title(f'Retrieval Performance by Cancer Type (Precision@{k})')
ax.tick_params(axis='x', labelrotation=45)
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 6. Analyze Retrieval Failures

In [None]:
# Identify hard cases: queries whose Precision@k falls below the threshold
k = 10
precision_threshold = 0.5

hard_cases = [
    {
        'index': query_idx,
        'patient_id': patient_ids[query_idx],
        'label': labels[query_idx],
        'precision': query_precision,
    }
    for query_idx, query_precision in enumerate(all_precisions[k])
    if query_precision < precision_threshold
]

print(f"Found {len(hard_cases)} hard cases (Precision@{k} < {precision_threshold})")

# Analyze confusion patterns: for every query, count the labels of its k
# retrieved neighbours (rows = query type, columns = retrieved type)
class_names = np.unique(labels)
label_to_idx = {label: idx for idx, label in enumerate(class_names)}
confusion_matrix = np.zeros((len(class_names), len(class_names)))

for query_idx, query_label in enumerate(labels):
    retrieved_indices, _ = retriever.retrieve(query_idx, k=k)
    row = label_to_idx[query_label]
    for ret_label in labels[retrieved_indices]:
        confusion_matrix[row, label_to_idx[ret_label]] += 1

# Normalize each row by its total so entries are fractions of retrieved items
confusion_matrix = confusion_matrix / confusion_matrix.sum(axis=1, keepdims=True)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_matrix, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Retrieval Confusion Matrix (Normalized)')
plt.ylabel('Query Cancer Type')
plt.xlabel('Retrieved Cancer Type')
plt.tight_layout()
plt.show()

## 7. Visualize Retrieval Examples

In [None]:
# Visualize specific retrieval examples
from matplotlib.lines import Line2D
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Select a few query examples (never more than there are patients, since
# sampling is without replacement)
n_examples = min(3, len(labels))
query_indices = np.random.choice(len(labels), n_examples, replace=False)

# squeeze=False always returns a 2-D grid of axes, so one example needs no
# special case.
fig, axes = plt.subplots(1, n_examples, figsize=(15, 5), squeeze=False)
axes = axes[0]

# Legend entries are identical for every panel; build them once.
legend_elements = [
    Line2D([0], [0], marker='*', color='w', markerfacecolor='blue',
           markersize=10, label='Query'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor='green',
           markersize=10, label='Correct'),
    Line2D([0], [0], marker='o', color='w', markerfacecolor='red',
           markersize=10, label='Incorrect')
]

for ax_idx, query_idx in enumerate(query_indices):
    # Get query and retrieved items; the query is row 0 of the plotted data
    retrieved_indices, distances = retriever.retrieve(query_idx, k=10)
    all_indices = np.concatenate([[query_idx], retrieved_indices])

    # Get embeddings for visualization
    vis_embeddings = embeddings[all_indices]
    n_points = vis_embeddings.shape[0]

    # Reduce dimensionality. PCA cannot produce more components than there
    # are samples (only ~11 points here), so cap n_components accordingly.
    if vis_embeddings.shape[1] > 50:
        pca = PCA(n_components=min(50, n_points), random_state=42)
        vis_embeddings = pca.fit_transform(vis_embeddings)

    # t-SNE requires perplexity < n_samples; the default (30) exceeds the
    # ~11 points plotted, so scale it to the number of points.
    perplexity = max(1.0, min(30.0, (n_points - 1) / 3))
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    vis_2d = tsne.fit_transform(vis_embeddings)

    # Plot
    ax = axes[ax_idx]

    # Plot retrieved items: green if same label as the query, red otherwise
    colors = ['green' if labels[idx] == labels[query_idx] else 'red'
              for idx in retrieved_indices]
    ax.scatter(vis_2d[1:, 0], vis_2d[1:, 1], c=colors, alpha=0.6, s=100)

    # Plot query
    ax.scatter(vis_2d[0, 0], vis_2d[0, 1], c='blue', s=200,
               marker='*', edgecolors='black', linewidth=2)

    # Add labels
    ax.set_title(f'Query: {labels[query_idx]}')
    ax.set_xlabel('t-SNE 1')
    ax.set_ylabel('t-SNE 2')
    ax.legend(handles=legend_elements, loc='best')

plt.tight_layout()
plt.show()

## 8. Compare Different Similarity Metrics

In [None]:
# Compare different distance metrics by their mean Precision@10


def _precision_at_10(metric_name):
    """Return mean Precision@10 of a retriever built with `metric_name` distance."""
    print(f"\nEvaluating {metric_name} distance...")
    scores, _ = calculate_precision_at_k(
        EmbeddingRetriever(embeddings, labels, metric=metric_name),
        k_values=[10],
    )
    return scores[10]


metrics = ['cosine', 'euclidean', 'manhattan']
metric_results = {name: _precision_at_10(name) for name in metrics}

# Plot comparison
metrics_list = list(metric_results)
precisions = [metric_results[name] for name in metrics_list]

fig, ax = plt.subplots(figsize=(8, 6))
bars = ax.bar(metrics_list, precisions)
ax.set_xlabel('Distance Metric')
ax.set_ylabel('Precision@10')
ax.set_title('Retrieval Performance by Distance Metric')
ax.set_ylim(0, 1)

# Write each bar's value just above its top edge
for bar, precision in zip(bars, precisions):
    ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
            f'{precision:.3f}', ha='center', va='bottom')

fig.tight_layout()
plt.show()

## 9. Multi-Modal Retrieval

In [None]:
# Create mock multi-modal embeddings
modalities = ['clinical', 'pathology', 'radiology']
modality_embeddings = {}

for modality in modalities:
    if modality != 'clinical':
        # Simulated modality: the clinical embeddings plus Gaussian noise
        # (std 0.5), so it stays correlated with the clinical view
        modality_embeddings[modality] = (
            embeddings + np.random.randn(*embeddings.shape) * 0.5
        )
    else:
        # Clinical embeddings are used as-is
        modality_embeddings[modality] = embeddings

# Evaluate each modality with a default (cosine) retriever
modality_results = {}

for modality, mod_embeddings in modality_embeddings.items():
    print(f"\nEvaluating {modality} modality...")
    mod_scores, _ = calculate_precision_at_k(
        EmbeddingRetriever(mod_embeddings, labels), k_values=[1, 5, 10, 20]
    )
    modality_results[modality] = mod_scores

# Plot comparison: one Precision@k curve per modality, sharing the same k axis
k_values = sorted(next(iter(modality_results.values())))

fig, ax = plt.subplots(figsize=(10, 6))
for modality, results in modality_results.items():
    precisions = [results[k] for k in k_values]
    ax.plot(k_values, precisions, 'o-', label=modality, linewidth=2, markersize=8)

ax.set_xlabel('k')
ax.set_ylabel('Precision@k')
ax.set_title('Retrieval Performance Across Modalities')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## 10. Advanced: Adjusted Mutual Information (AMI) Analysis

In [None]:
# Calculate AMI scores for different k values
def calculate_ami_scores(retriever, k_values=(5, 10, 20, 50)):
    """
    Calculate Adjusted Mutual Information between true labels and
    k-nearest-neighbour label assignments.

    For each k, every item is assigned the majority label among its k
    retrieved neighbours (query excluded). Ties go to the label seen first,
    i.e. the one held by the closest neighbour. AMI is then computed once
    over the whole dataset between these predicted labels and the true
    labels. Label-pure neighbourhoods give 1.0; neighbourhoods carrying no
    label information give roughly 0.

    (The previous version scored "query vs retrieved" splits per query, for
    which a perfect retriever, whose neighbours all share the query's label,
    gets an AMI of 0.)

    Parameters
    ----------
    retriever : EmbeddingRetriever
        Object exposing ``labels`` (NumPy array) and ``retrieve(idx, k)``.
    k_values : iterable of int
        Neighbourhood sizes to evaluate.

    Returns
    -------
    dict
        k -> AMI score (NaN if no item had any retrievable neighbour).
    """
    from collections import Counter

    k_values = list(k_values)
    if not k_values:
        return {}

    true_labels = retriever.labels
    max_k = max(k_values)

    # Retrieve once per query at the largest k and slice for smaller ones.
    neighbour_labels = []
    for query_idx in range(len(true_labels)):
        retrieved_indices, _ = retriever.retrieve(query_idx, k=max_k)
        neighbour_labels.append(true_labels[retrieved_indices])

    ami_scores = {}
    for k in k_values:
        print(f"Calculating AMI for k={k}...")
        predicted, actual = [], []
        for query_idx, neighbours in enumerate(neighbour_labels):
            top_k = neighbours[:k]
            if len(top_k) == 0:
                continue  # nothing to vote with
            # most_common is stable, so ties favour the nearest neighbour
            predicted.append(Counter(top_k).most_common(1)[0][0])
            actual.append(true_labels[query_idx])

        ami_scores[k] = (
            adjusted_mutual_info_score(actual, predicted) if actual else float('nan')
        )

    return ami_scores

# Calculate AMI scores for the default retriever
ami_scores = calculate_ami_scores(retriever)

# Plot AMI score against neighbourhood size k
k_vals = sorted(ami_scores)
scores = [ami_scores[k] for k in k_vals]

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(k_vals, scores, 'o-', linewidth=2, markersize=8)
ax.set_xlabel('k')
ax.set_ylabel('AMI Score')
ax.set_title('Adjusted Mutual Information vs k')
ax.grid(True, alpha=0.3)

# Annotate each point with its score
for k_point, score in zip(k_vals, scores):
    ax.annotate(f'{score:.3f}', (k_point, score), textcoords="offset points",
                xytext=(0, 10), ha='center')

fig.tight_layout()
plt.show()

## Summary and Next Steps

In this workshop, you learned to:
1. ✅ Implement similarity-based retrieval systems
2. ✅ Evaluate retrieval with Precision@k metrics
3. ✅ Analyze retrieval failures and confusion patterns
4. ✅ Compare different distance metrics and modalities
5. ✅ Calculate AMI scores for clustering quality

**Next Workshop**: Part 6 - Survival Analysis

**Key Takeaways**:
- Embeddings enable effective similarity-based retrieval
- Cosine distance ignores embedding magnitude; for L2-normalized embeddings it produces the same neighbor ranking as Euclidean distance
- Retrieval performance varies by cancer type and modality
- Failure analysis reveals which cancer types are easily confused

**Exercise**: Try implementing re-ranking strategies or learning-to-rank approaches!