# Metrics Notebook
This is a simple notebook designed to run all the metrics for a given variant of CLIP.  It expects a model as defined in the config below that is "CLIP-like" in that it can take in an image or text and output an embedding of some size.  The clip-like model should specifically adhere to the Clip interface defined in `models/clip.py`

The four metrics we implement are outlined here: 

### Top-K Retrieval Accuracy.
Given an image, we compute its CLIP embedding and retrieve the closest K captions based on cosine similarity with caption embeddings. If the target caption is within the top K, we count this as a correct retrieval. This metric is a direct proxy for classification accuracy in multimodal retrieval. It is valuable because strong cross-modal alignment should yield high retrieval accuracy. However, since our approach aims to reduce sparsity on the hypersphere, the embeddings may become less linearly separable, potentially lowering retrieval performance even as uniformity improves.

### Modality Gap via Linear Separability.
Following \citet{modalityGAP}, we measure the modality gap between text and image embeddings by training a soft-margin SVM classifier to distinguish modality type. We evaluate classification accuracy, precision, and recall. High separability indicates a strong modality gap, which is undesirable because semantically matched image–text pairs should ideally share indistinguishable representations. Reducing modality separability would thus reflect improved multimodal coordination.

### Hyperspherical Entropy Estimation.
We measure the entropy of the embedding distribution on the hypersphere using the k-nearest neighbor–based estimator proposed by \citet{entropy}. This estimator leverages angular distances to compute local density estimates, which are aggregated into a global entropy measure. Entropy serves as a proxy for sparsity: low entropy distributions are clustered and “spiky,” while high entropy indicates more uniform coverage of the hypersphere. Since our method encourages uniformity, we expect an increase in entropy relative to standard CLIP.

### Downstream Captioning Performance (BLEU)
Finally, we evaluate the utility of embeddings on a generative downstream task: image captioning. Image embeddings are passed into a pretrained language model to generate captions, which are compared against ground-truth captions using BLEU score. BLEU measures n-gram overlap between generated and reference text, rewarding fluency and accuracy. This extrinsic metric demonstrates how improvements in embedding geometry translate into practical benefits for end-user tasks, beyond abstract geometric properties.

In [1]:
from models.clipModel import CLIPModel


model = CLIPModel()

In [None]:
import torch

def top_k_similarities(embeddings, query_embedding, k=5):
    """
    Compute the top-k most similar embeddings to the query_embedding.
    
    Args:
        embeddings (torch.Tensor): Tensor of shape (N, D) where N is the number of embeddings and D is the embedding dimension.
        query_embedding (torch.Tensor): Tensor of shape (D,) representing the query embedding.
        k (int): Number of top similar embeddings to return.

    Returns:
        List[Tuple[int, float]]: List of tuples containing the index and similarity score of the top-k most similar embeddings.
    """
    # Compute cosine similarities
    similarities = torch.nn.functional.cosine_similarity(embeddings, query_embedding.unsqueeze(0), dim=1)

    # Get top-k indices
    top_k_indices = similarities.topk(k).indices

    # Return list of (index, similarity) tuples
    return [(idx.item(), similarities[idx].item()) for idx in top_k_indices]

def top_k_score(embedding_pairs, k=5):
    """
    Given a list of (text_embedding[], image_embedding) pairs, return the percentage of texts that are in the top-k most similar to their corresponding image embeddings.
    """
    correct_count = 0
    for text_embeddings, image_embedding in embedding_pairs:
        top_k = top_k_similarities(text_embeddings, image_embedding, k)
        if 0 in [idx for idx, _ in top_k]:  # Assuming the correct text is always at index 0
            correct_count += 1
    return correct_count / len(embedding_pairs) if embedding_pairs else 0.0

In [None]:
import torch.nn as nn

def linear_separability(image_embeddings, text_embeddings, num_epochs=100, learning_rate=1e-3):
    """
    Train a linear classifier to distinguish between image and text embeddings, and report the accuracy.
    
    Args:
        image_embeddings (torch.Tensor): Tensor of shape (N, D) for image embeddings.
        text_embeddings (torch.Tensor): Tensor of shape (N, D) for text embeddings.
        num_epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for the optimizer.

    Returns:
        float: Accuracy of the classifier on the given set.
    """
    # Combine image and text embeddings
    embeddings = torch.cat([image_embeddings, text_embeddings], dim=0)
    labels = torch.cat([torch.zeros(image_embeddings.size(0)), torch.ones(text_embeddings.size(0))], dim=0)

    # Train a linear classifier
    classifier = nn.Linear(embeddings.size(1), 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        outputs = classifier(embeddings)
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()

    # Evaluate the classifier
    with torch.no_grad():
        outputs = classifier(embeddings)
        preds = outputs.argmax(dim=1)
        accuracy = (preds == labels).float().mean().item()

    return accuracy

In [None]:
def bleu_score(predictions, references):
    """
    Compute a simple BLEU score for a list of predictions and references.
    
    Args:
        predictions (List[str]): List of predicted sentences.
        references (List[str]): List of reference sentences.

    Returns:
        float: Average BLEU score across all predictions.
    """
    from nltk.translate.bleu_score import sentence_bleu

    total_score = 0.0
    for pred, ref in zip(predictions, references):
        ref_tokens = [ref.split()]
        pred_tokens = pred.split()
        score = sentence_bleu(ref_tokens, pred_tokens)
        total_score += score

    return total_score / len(predictions) if predictions else 0.0

In [None]:
def hypersphere_cap_area(phi, d):
        """
        Compute the area of a spherical cap with angle phi on a (d-1)-sphere.
        
        S(φ) = (1/2) * S_p * [1 - sgn(cos φ) * I_{cos²φ}(1/2, (p-1)/2)]
        
        where S_p is the surface area of the (d-1)-sphere and I is the regularized 
        incomplete beta function.
        """
        # Surface area of (d-1)-sphere: S_p = 2π^(d/2) / Γ(d/2)
        from scipy.special import gamma, betainc
        S_p = 2 * (np.pi ** (d/2)) / gamma(d/2)
        print(S_p)  # --- IGNORE ---
        
        cos_phi = np.cos(phi)
        cos_phi_squared = cos_phi ** 2
        cos_phi_sgn = np.sign(cos_phi)

        alpha = 0.5
        beta = (d - 1) / 2
        incomplete_beta = betainc(alpha, beta, cos_phi_squared)
        print((1 - cos_phi_sgn * incomplete_beta))
        cap_area = 0.5 * S_p * (1 - cos_phi_sgn * incomplete_beta)
        assert cap_area > 0, f"Cap area must be positive, not {cap_area}."
        return cap_area

In [104]:
from scipy.special import gamma, betainc

d = 128
2 * (np.pi ** (d/2)) / gamma(d/2)

np.float64(6.628037009147083e-56)

In [86]:
hypersphere_cap_area(.1, 3)

12.566370614359174
0.0049958347219741794


np.float64(0.03138975532220578)

In [6]:
import torch
import numpy as np
from scipy.special import digamma, beta
from scipy.spatial.distance import pdist, squareform

def knn_entropy_notgpt(embeddings, k=5):
    """
    Compute the k-nearest neighbor entropy estimator for hyperspherical data.
    
    This estimator is designed for data on a unit hypersphere and uses the 
    k-nearest neighbor approach to estimate entropy consistently.
    
    Args:
        embeddings (torch.Tensor): Tensor of shape (N, D) representing N embeddings on the unit hypersphere of dimension D-1.
        k (int): Number of nearest neighbors to consider.
    
    Returns:
        float: Estimated entropy of the distribution.
    """
    # Ensure embeddings are normalized to unit sphere
    embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
    embeddings_np = embeddings.detach().cpu().numpy()
    
    n, d = embeddings_np.shape
    
    # Compute pairwise angular distances using arccos(x^T y)
    # Since embeddings are normalized, dot product gives cosine similarity
    dot_products = np.dot(embeddings_np, embeddings_np.T)
    # Clamp to avoid numerical issues with arccos
    dot_products = np.clip(dot_products, -1.0, 1.0)
    angular_distances = np.arccos(dot_products)
    
    # For each point, find the k-th nearest neighbor distance
    phi_values = []
    for i in range(n):
        # Get distances to all other points (excluding self)
        distances_to_i = angular_distances[i]
        distances_to_i = np.delete(distances_to_i, i)  # Remove self-distance (which is 0)
        
        # Sort and get k-th nearest neighbor distance
        sorted_distances = np.sort(distances_to_i)
        phi_i = sorted_distances[k-1]  # k-th nearest (0-indexed)
        phi_values.append(phi_i)

    phi_values = np.array(phi_values)
    print(phi_values)

    # Compute cap areas for all phi values
    S_phi = np.array([hypersphere_cap_area(phi, d) for phi in phi_values])
    # Compute L_{n,i} = ln(f_n(X_i)) = ln(k/n / S(phi_i))
    L_values = np.log(k/n) - np.log(S_phi)
    
    # Compute digamma function ψ(k)
    psi_k = digamma(k)
    
    # Compute entropy using the first formulation:
    # H_n(f) = -(1/n) * Σ[L_{n,i} - ln(k) + ψ(k)]
    entropy = -(1/n) * np.sum(L_values - np.log(k) + psi_k)
    
    return entropy


In [10]:
import numpy as np
from scipy.special import digamma, gamma
from sklearn.neighbors import NearestNeighbors

def knn_entropy(X, k=4):
    print("X shape: ", X.shape)
    if isinstance(X, torch.Tensor):
        X = X.detach().cpu().numpy()
    X /= np.linalg.norm(X, axis=1, keepdims=True)
    # X: (N, D) array, rows are unit vectors
    N, D = X.shape
    d = D - 1  # intrinsic dimension of S^{D-1}

    # compute cosine similarities and clip for numerical safety
    sims = X.dot(X.T)
    np.fill_diagonal(sims, 1.0)
    sims = np.clip(sims, -1.0, 1.0)

    # geodesic distances (great-circle)
    ang = np.arccos(sims)   # matrix of angles in [0, pi]
    print("hello")
    # use NearestNeighbors on precomputed distances
    nbrs = NearestNeighbors(n_neighbors=k+1, metric='precomputed')
    nbrs.fit(ang)  # note: fit expects (N, N) precomputed distance matrix
    print("hello2")
    distances, indices = nbrs.kneighbors(ang)
    # distances[:,0] is zero (self); k-th neighbor is distances[:, k]
    r_k = distances[:, k]   # geodesic radius to k-th neighbor

    eps = 2.0 * r_k  # diameter as in KL notation

    print("hello3")
    # volume constant c_d for Euclidean d-ball but note factor 2^d in some conventions
    c_d = (np.pi**(d/2)) / (gamma(1 + d/2) * (2.0**d))

    H_hat = -digamma(k) + digamma(N) + np.log(c_d) + (d / N) * np.sum(np.log(eps + 1e-12))
    print(-digamma(k), digamma(N), np.log(c_d), (d / N) * np.sum(np.log(eps + 1e-12)))
    return H_hat  # in nats 

In [11]:
# Test knn entropy
X_test = np.random.randn(300, 128)
 # normalize to unit sphere
entropy_estimate = knn_entropy(X_test, k=5)
print("Estimated entropy:", entropy_estimate)

X shape:  (300, 128)
hello
hello2
hello3
-1.5061176684318003 5.702114882064637 -218.42614901550428 128.84635272811022
Estimated entropy: -85.38379907376122


In [12]:
# Alternative implementation using the second formulation for verification
def knn_entropy_alternative(embeddings, k=5):
    """
    Alternative implementation using the second formulation:
    H_n(f) = (1/n) * Σ ln[n * S(φ_i)] - ψ(k)
    """
    # Ensure embeddings are normalized to unit sphere
    embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
    embeddings_np = embeddings.detach().cpu().numpy()
    
    n, d = embeddings_np.shape
    
    # Compute pairwise angular distances
    dot_products = np.dot(embeddings_np, embeddings_np.T)
    dot_products = np.clip(dot_products, -1.0, 1.0)
    angular_distances = np.arccos(dot_products)
    
    # Find k-th nearest neighbor distances
    phi_values = []
    for i in range(n):
        distances_to_i = angular_distances[i]
        distances_to_i = np.delete(distances_to_i, i)
        sorted_distances = np.sort(distances_to_i)
        phi_i = sorted_distances[k-1]
        phi_values.append(phi_i)
    
    phi_values = np.array(phi_values)
    
    # Compute cap areas using the same simplified geometric formula
    def hypersphere_cap_area(phi, d):
        """
        Compute the area of a spherical cap with angular radius phi on a unit (d-1)-sphere.
        This is the same simplified implementation as in the main function.
        """
        from scipy.special import gamma
        
        # Surface area of the full (d-1)-sphere: S_d = 2π^(d/2) / Γ(d/2)
        S_full = 2 * (np.pi ** (d/2)) / gamma(d/2)
        
        # For small angles, use series expansion to avoid numerical issues
        if phi < 1e-10:
            cap_area = (np.pi ** ((d-1)/2) / gamma((d+1)/2)) * (phi ** (d-1))
        elif phi >= np.pi:
            cap_area = S_full
        else:
            if d == 2:
                # Special case for 2D (circle): cap area = 2 * phi
                cap_area = 2 * phi
            elif d == 3:
                # Special case for 3D (sphere): cap area = 2π * (1 - cos(phi))
                cap_area = 2 * np.pi * (1 - np.cos(phi))
            else:
                # General case: use numerical integration
                from scipy.integrate import quad
                
                def integrand(theta):
                    return np.sin(theta) ** (d - 2)
                
                numerator, _ = quad(integrand, 0, phi)
                denominator, _ = quad(integrand, 0, np.pi)
                cap_area = S_full * (numerator / denominator)
        
        return max(cap_area, 1e-15)
    
    S_phi = np.array([hypersphere_cap_area(phi, d) for phi in phi_values])
    
    # Second formulation: H_n(f) = (1/n) * Σ ln[n * S(φ_i)] - ψ(k)
    psi_k = digamma(k)
    entropy = (1/n) * np.sum(np.log(n * S_phi)) - psi_k
    
    return entropy

In [None]:
# Test the hyperspherical entropy estimator on COCO data
import torch
import numpy as np
from datasetLoader import DatasetLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

def test_entropy_estimator_on_coco():
    """
    Test the hyperspherical entropy estimator on COCO embeddings.
    This will:
    1. Load COCO dataset samples
    2. Generate CLIP embeddings for images and text
    3. Compute entropy for both modalities
    4. Compare with random uniform embeddings as a baseline
    """
    print("Testing Hyperspherical Entropy Estimator on COCO Data")
    print("=" * 60)
    
    # Load COCO dataset samples
    try:
        print("Loading COCO dataset...")
        data_samples = DatasetLoader.load_coco_dataset(
            data_dir="data",
            split="val2017",
            max_samples=500  # Use smaller sample for testing
        )
        print(f"Loaded {len(data_samples)} COCO samples")
    except Exception as e:
        print(f"Could not load COCO dataset: {e}")
        print("Generating synthetic data for testing...")
        # Create synthetic data for testing
        data_samples = []
        for i in range(100):
            data_samples.append({
                "image_path": f"dummy_image_{i}.jpg",
                "text": f"A sample caption number {i}"
            })
    
    # Generate embeddings using the CLIP model
    print("\nGenerating CLIP embeddings...")
    
    # For testing, we'll use a subset to make it faster
    test_samples = data_samples[:100]
    
    try:
        # Generate image embeddings
        print("Generating image embeddings...")
        image_paths = [sample["image_path"] for sample in test_samples]
        image_embeddings = []
        
        # Process in smaller batches to avoid memory issues
        batch_size = 10
        for i in tqdm(range(0, len(image_paths), batch_size)):
            batch_paths = image_paths[i:i+batch_size]
            try:
                batch_embeddings = model.encode_images(batch_paths)
                image_embeddings.extend(batch_embeddings)
            except Exception as e:
                print(f"Error encoding batch {i//batch_size}: {e}")
                # Generate random embeddings as fallback
                for _ in batch_paths:
                    random_emb = torch.randn(512)  # Assume 512-dim embeddings
                    random_emb = random_emb / torch.norm(random_emb)  # Normalize to unit sphere
                    image_embeddings.append(random_emb)
        
        image_embeddings = torch.stack(image_embeddings)
        
        # Generate text embeddings
        print("Generating text embeddings...")
        texts = [sample["text"] for sample in test_samples]
        text_embeddings = []
        
        for i in tqdm(range(0, len(texts), batch_size)):
            batch_texts = texts[i:i+batch_size]
            try:
                batch_embeddings = model.encode_text(batch_texts)
                text_embeddings.extend(batch_embeddings)
            except Exception as e:
                print(f"Error encoding text batch {i//batch_size}: {e}")
                # Generate random embeddings as fallback
                for _ in batch_texts:
                    random_emb = torch.randn(512)
                    random_emb = random_emb / torch.norm(random_emb)
                    text_embeddings.append(random_emb)
        
        text_embeddings = torch.stack(text_embeddings)
        
    except Exception as e:
        print(f"Error generating embeddings: {e}")
        print("Using synthetic embeddings for testing...")
        
        # Generate synthetic embeddings for testing
        embedding_dim = 512
        num_samples = len(test_samples)
        
        # Generate clustered embeddings (lower entropy)
        image_embeddings = torch.randn(num_samples, embedding_dim)
        image_embeddings = image_embeddings / torch.norm(image_embeddings, dim=1, keepdim=True)
        
        text_embeddings = torch.randn(num_samples, embedding_dim)
        text_embeddings = text_embeddings / torch.norm(text_embeddings, dim=1, keepdim=True)
    
    print(f"\nImage embeddings shape: {image_embeddings.shape}")
    print(f"Text embeddings shape: {text_embeddings.shape}")
    
    # Test entropy estimation with different k values
    k_values = [3, 5, 7, 10, 15]
    
    print("\nComputing entropy estimates...")
    results = {
        'k_values': k_values,
        'image_entropy': [],
        'text_entropy': [],
        'image_entropy_alt': [],
        'text_entropy_alt': []
    }
    
    for k in k_values:
        print(f"\nTesting with k={k}:")
        
        # Compute entropy for image embeddings
        try:
            img_entropy = knn_entropy(image_embeddings, k=k)
            print(f"  Image entropy (method 1): {img_entropy:.4f}")
            
            results['image_entropy'].append(img_entropy)
        except Exception as e:
            print(f"  Error computing image entropy: {e}")
            results['image_entropy'].append(np.nan)
            results['image_entropy_alt'].append(np.nan)
        
        # Compute entropy for text embeddings
        try:
            txt_entropy = knn_entropy(text_embeddings, k=k)
            print(f"  Text entropy (method 1): {txt_entropy:.4f}")
            
            results['text_entropy'].append(txt_entropy)
        except Exception as e:
            print(f"  Error computing text entropy: {e}")
            results['text_entropy'].append(np.nan)
            results['text_entropy_alt'].append(np.nan)
    
    # Generate baseline: uniform random embeddings on hypersphere
    print("\nGenerating uniform random baseline...")
    embedding_dim = image_embeddings.shape[1]
    num_samples = image_embeddings.shape[0]
    
    # Generate uniform random embeddings on hypersphere
    random_embeddings = torch.randn(num_samples, embedding_dim)
    random_embeddings = random_embeddings / torch.norm(random_embeddings, dim=1, keepdim=True)
    
    print("Computing entropy for uniform random embeddings...")
    random_entropies = []
    for k in k_values:
        try:
            random_entropy = knn_entropy(random_embeddings, k=k)
            random_entropies.append(random_entropy)
            print(f"  Random entropy (k={k}): {random_entropy:.4f}")
        except Exception as e:
            print(f"  Error computing random entropy (k={k}): {e}")
            random_entropies.append(np.nan)
    
    results['random_entropy'] = random_entropies
    
    # Plot results
    plt.figure(figsize=(15, 5))
    
    # Plot 1: Entropy vs k for different data types
    plt.subplot(1, 3, 1)
    plt.plot(k_values, results['image_entropy'], 'bo-', label='Image Embeddings', markersize=6)
    plt.plot(k_values, results['text_entropy'], 'ro-', label='Text Embeddings', markersize=6)
    plt.plot(k_values, results['random_entropy'], 'go-', label='Random Uniform', markersize=6)
    plt.xlabel('k (number of neighbors)')
    plt.ylabel('Entropy Estimate')
    plt.title('Hyperspherical Entropy vs k')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Plot 2: Comparison of two estimation methods
    plt.subplot(1, 3, 2)
    plt.plot(k_values, results['image_entropy'], 'b-', label='Image (Method 1)', linewidth=2)
    plt.plot(k_values, results['text_entropy'], 'r-', label='Text (Method 1)', linewidth=2)
    plt.xlabel('k (number of neighbors)')
    plt.ylabel('Entropy Estimate')
    plt.title('Comparison of Estimation Methods')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Plot 3: Histogram of embedding norms (should be close to 1)
    plt.subplot(1, 3, 3)
    img_norms = torch.norm(image_embeddings, dim=1).detach().numpy()
    txt_norms = torch.norm(text_embeddings, dim=1).detach().numpy()
    plt.hist(img_norms, bins=20, alpha=0.6, label='Image Embeddings', density=True)
    plt.hist(txt_norms, bins=20, alpha=0.6, label='Text Embeddings', density=True)
    plt.axvline(x=1.0, color='red', linestyle='--', label='Unit Norm')
    plt.xlabel('Embedding Norm')
    plt.ylabel('Density')
    plt.title('Distribution of Embedding Norms')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Summary statistics
    print("\n" + "="*60)
    print("ENTROPY ESTIMATION SUMMARY")
    print("="*60)
    
    # Average entropy across k values (excluding NaN)
    img_avg = np.nanmean(results['image_entropy'])
    txt_avg = np.nanmean(results['text_entropy'])
    rand_avg = np.nanmean(results['random_entropy'])
    
    print(f"Average Entropy Estimates:")
    print(f"  Image Embeddings: {img_avg:.4f}")
    print(f"  Text Embeddings:  {txt_avg:.4f}")
    print(f"  Random Uniform:   {rand_avg:.4f}")
    
    # Check if embeddings are properly normalized
    img_norm_mean = torch.norm(image_embeddings, dim=1).mean().item()
    txt_norm_mean = torch.norm(text_embeddings, dim=1).mean().item()
    print(f"\nEmbedding Normalization Check:")
    print(f"  Image embeddings mean norm: {img_norm_mean:.6f}")
    print(f"  Text embeddings mean norm:  {txt_norm_mean:.6f}")
    print(f"  (Should be close to 1.0 for unit hypersphere)")
    
    # Interpretation
    print(f"\nInterpretation:")
    if rand_avg > max(img_avg, txt_avg):
        print("✓ Random uniform embeddings have higher entropy than CLIP embeddings")
        print("  This suggests CLIP embeddings are more clustered/structured")
    else:
        print("! CLIP embeddings have higher entropy than uniform random")
        print("  This might indicate issues with the estimation or data")
    
    if abs(img_avg - txt_avg) > 0.5:
        modality_gap = "large"
    elif abs(img_avg - txt_avg) > 0.2:
        modality_gap = "moderate"
    else:
        modality_gap = "small"
    
    print(f"  Modality gap in entropy: {modality_gap} ({abs(img_avg - txt_avg):.3f})")
    
    return results

# Run the test
test_results = test_entropy_estimator_on_coco()

2025-10-20 18:31:28,838 - INFO - Loading COCO val2017 dataset...
2025-10-20 18:31:28,902 - INFO - Loaded 500 COCO samples


Testing Hyperspherical Entropy Estimator on COCO Data
Loading COCO dataset...
Loaded 500 COCO samples

Generating CLIP embeddings...
Generating image embeddings...


100%|██████████| 10/10 [00:02<00:00,  3.46it/s]


Generating text embeddings...


100%|██████████| 10/10 [00:02<00:00,  3.78it/s]
  H_hat = -digamma(k) + digamma(N) + np.log(c_d) + (d / N) * np.sum(np.log(eps + 1e-12))
  print(-digamma(k), digamma(N), np.log(c_d), (d / N) * np.sum(np.log(eps + 1e-12)))
  H_hat = -digamma(k) + digamma(N) + np.log(c_d) + (d / N) * np.sum(np.log(eps + 1e-12))
  print(-digamma(k), digamma(N), np.log(c_d), (d / N) * np.sum(np.log(eps + 1e-12)))
  H_hat = -digamma(k) + digamma(N) + np.log(c_d) + (d / N) * np.sum(np.log(eps + 1e-12))
  print(-digamma(k), digamma(N), np.log(c_d), (d / N) * np.sum(np.log(eps + 1e-12)))
  H_hat = -digamma(k) + digamma(N) + np.log(c_d) + (d / N) * np.sum(np.log(eps + 1e-12))
  print(-digamma(k), digamma(N), np.log(c_d), (d / N) * np.sum(np.log(eps + 1e-12)))
  H_hat = -digamma(k) + digamma(N) + np.log(c_d) + (d / N) * np.sum(np.log(eps + 1e-12))
  print(-digamma(k), digamma(N), np.log(c_d), (d / N) * np.sum(np.log(eps + 1e-12)))
  H_hat = -digamma(k) + digamma(N) + np.log(c_d) + (d / N) * np.sum(np.log(eps + 1


Image embeddings shape: torch.Size([100, 512])
Text embeddings shape: torch.Size([100, 512])

Computing entropy estimates...

Testing with k=3:
X shape:  torch.Size([100, 512])
hello
hello2
hello3
-0.9227843350984671 4.600161852738088 -inf -5803.1406
  Image entropy (method 1): -inf
X shape:  torch.Size([100, 512])
hello
hello2
hello3
-0.9227843350984671 4.600161852738088 -inf 149.63445
  Text entropy (method 1): -inf

Testing with k=5:
X shape:  torch.Size([100, 512])
hello
hello2
hello3
-1.5061176684318003 4.600161852738088 -inf 241.86328
  Image entropy (method 1): -inf
X shape:  torch.Size([100, 512])
hello
hello2
hello3
-1.5061176684318003 4.600161852738088 -inf 223.04105
  Text entropy (method 1): -inf

Testing with k=7:
X shape:  torch.Size([100, 512])
hello
hello2
hello3
-1.872784335098467 4.600161852738088 -inf 261.00137
  Image entropy (method 1): -inf
X shape:  torch.Size([100, 512])
hello
hello2
hello3
-1.872784335098467 4.600161852738088 -inf 248.91978
  Text entropy (met

In [None]:
def validate_entropy_estimator():
    """
    Validate the entropy estimator with synthetic data where we can control the distribution properties.
    """
    print("Validating Entropy Estimator with Synthetic Data")
    print("=" * 50)
    
    # Test parameters
    embedding_dim = 128
    num_samples = 300
    k = 30
    
    # Test 1: Uniform random distribution on hypersphere (should have high entropy)
    print("\nTest 1: Uniform Random Distribution")
    uniform_embeddings = torch.randn(num_samples, embedding_dim)
    uniform_embeddings = uniform_embeddings / torch.norm(uniform_embeddings, dim=1, keepdim=True)

    uniform_entropy_1 = knn_entropy(uniform_embeddings, k=k)
    
    print(f"  Uniform entropy (method 1): {uniform_entropy_1:.4f}")
    
    # Test 2: Clustered distribution (should have lower entropy)
    print("\nTest 2: Clustered Distribution (3 clusters)")
    clustered_embeddings = []
    samples_per_cluster = num_samples // 3
    
    # Create 3 clusters
    cluster_centers = torch.randn(3, embedding_dim)
    cluster_centers = cluster_centers / torch.norm(cluster_centers, dim=1, keepdim=True)
    
    for i in range(3):
        # Generate points around each cluster center
        cluster_points = cluster_centers[i].unsqueeze(0) + 0.01 * torch.randn(samples_per_cluster, embedding_dim)
        cluster_points = cluster_points / torch.norm(cluster_points, dim=1, keepdim=True)
        clustered_embeddings.append(cluster_points)
    
    clustered_embeddings = torch.cat(clustered_embeddings, dim=0)
    clustered_entropy_1 = knn_entropy(clustered_embeddings, k=k)
    
    print(f"  Clustered entropy (method 1): {clustered_entropy_1:.4f}")
    
    # Test 3: Very concentrated distribution (should have very low entropy)
    print("\nTest 3: Highly Concentrated Distribution")
    concentrated_embeddings = torch.randn(1, embedding_dim) + 0.01 * torch.randn(num_samples, embedding_dim)  # Small variance
    concentrated_embeddings = concentrated_embeddings / torch.norm(concentrated_embeddings, dim=1, keepdim=True)
    
    concentrated_entropy_1 = knn_entropy(concentrated_embeddings, k=k)
    
    print(f"  Concentrated entropy (method 1): {concentrated_entropy_1:.4f}")
    
    # Visualization
    plt.figure(figsize=(12, 8))
    
    # Plot entropy comparison
    plt.subplot(2, 2, 1)
    distributions = ['Uniform', 'Clustered', 'Concentrated']
    entropies_1 = [uniform_entropy_1, clustered_entropy_1, concentrated_entropy_1]
    
    x = np.arange(len(distributions))
    width = 0.35
    
    plt.bar(x - width/2, entropies_1, width, label='Method 1', alpha=0.8)
    plt.xlabel('Distribution Type')
    plt.ylabel('Entropy Estimate')
    plt.title('Entropy Estimates for Different Distributions')
    plt.xticks(x, distributions)
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Plot 2D projections to visualize distributions (for first 2 dimensions)
    plt.subplot(2, 2, 2)
    plt.scatter(uniform_embeddings[:, 0], uniform_embeddings[:, 1], alpha=0.6, s=20, label='Uniform')
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    plt.title('Uniform Distribution (2D Projection)')
    plt.axis('equal')
    plt.grid(True, alpha=0.3)
    
    plt.subplot(2, 2, 3)
    colors = ['red', 'blue', 'green']
    for i in range(3):
        start_idx = i * samples_per_cluster
        end_idx = (i + 1) * samples_per_cluster
        plt.scatter(clustered_embeddings[start_idx:end_idx, 0], 
                   clustered_embeddings[start_idx:end_idx, 1], 
                   alpha=0.6, s=20, c=colors[i], label=f'Cluster {i+1}')
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    plt.title('Clustered Distribution (2D Projection)')
    plt.legend()
    plt.axis('equal')
    plt.grid(True, alpha=0.3)
    
    plt.subplot(2, 2, 4)
    plt.scatter(concentrated_embeddings[:, 0], concentrated_embeddings[:, 1], alpha=0.6, s=20, color='purple')
    plt.xlabel('Dimension 1')
    plt.ylabel('Dimension 2')
    plt.title('Concentrated Distribution (2D Projection)')
    plt.axis('equal')
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Summary and validation
    print("\n" + "="*50)
    print("VALIDATION SUMMARY")
    print("="*50)
    
    # Check if entropy ordering makes sense
    entropy_ordering_correct = uniform_entropy_1 > clustered_entropy_1 > concentrated_entropy_1
    
    print(f"Expected entropy ordering (Uniform > Clustered > Concentrated): {'✓' if entropy_ordering_correct else '✗'}")
    
    if entropy_ordering_correct:
        print("✓ Entropy estimator validation PASSED")
        print("  The estimator correctly ranks distributions by their expected entropy levels")
    else:
        print("✗ Entropy estimator validation FAILED")
        print("  Check implementation or increase sample size")
    
    return {
        'uniform_entropy': (uniform_entropy_1),
        'clustered_entropy': (clustered_entropy_1),
        'concentrated_entropy': (concentrated_entropy_1),
        'validation_passed': entropy_ordering_correct
    }

# Run validation
validation_results = validate_entropy_estimator()

Validating Entropy Estimator with Synthetic Data

Test 1: Uniform Random Distribution
X shape:  torch.Size([300, 128])


  X /= np.linalg.norm(X, axis=1, keepdims=True)


RuntimeError: 1D tensors expected, but got 2D and 2D tensors

## CLIP Captioning Model Training

The code above provides a complete training pipeline for CLIP-based image captioning using GPT-2. Here's how to use it:

### Key Components:

1. **CaptionTrainingDataset**: Custom dataset that uses any CLIP model adhering to the `ClipModel` interface to generate image embeddings paired with captions.

2. **train_clip_captioning_model()**: Main training function that:
   - Sets up the ClipCaptionModel with proper dimensions based on the CLIP model
   - Trains using cross-entropy loss on next-token prediction
   - Evaluates using BLEU score on generated captions
   - Saves checkpoints during training

3. **evaluate_captioning_model()**: Evaluation function that computes both loss and BLEU score

4. **generate_caption()**: Caption generation function for inference

### Usage Example:

```python
# Train a captioning model using the loaded CLIP model
trained_model, training_losses = train_caption_model_on_coco(
    clip_model=model,  # Your CLIP model from above
    data_dir="data",   # Directory containing COCO dataset
    max_samples=1000,  # Number of samples to use
    batch_size=8,      # Adjust based on GPU memory
    num_epochs=5       # Number of training epochs
)

# Generate a caption for a new image
image_embedding = model.encode_images(["path/to/image.jpg"])
caption = generate_caption(trained_model, image_embedding, tokenizer, device="cuda")
print(f"Generated caption: {caption}")
```

### BLEU Score Evaluation:

The training pipeline automatically evaluates the model using BLEU score, which measures n-gram overlap between generated and reference captions. This provides a quantitative measure of caption quality that complements the training loss.

In [None]:
from models.clipModel import CLIPModel

model = CLIPModel()

In [None]:
# Import necessary libraries for captioning model training
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image
import json
from datasetLoader import DatasetLoader
from models.clipCaptionModel import ClipCaptionModel
from typing import List, Tuple, Optional
import numpy as np

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class CaptionTrainingDataset(Dataset):
    """Dataset for training CLIP captioning model."""
    
    def __init__(self, data_samples, clip_model, tokenizer, max_length=77):
        """
        Args:
            data_samples: List of data samples from DatasetLoader
            clip_model: CLIP model for generating image embeddings
            tokenizer: GPT-2 tokenizer
            max_length: Maximum sequence length for tokenization
        """
        self.data_samples = data_samples
        self.clip_model = clip_model
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # Pre-compute image embeddings to avoid recomputing during training
        self.image_embeddings = self._precompute_image_embeddings()
        
    def _precompute_image_embeddings(self):
        """Pre-compute image embeddings for all samples."""
        print("Pre-computing image embeddings...")
        image_paths = [sample["image_path"] for sample in self.data_samples]
        
        # Generate embeddings in batches to save memory
        embeddings = []
        batch_size = 32
        
        for i in tqdm(range(0, len(image_paths), batch_size)):
            batch_paths = image_paths[i:i+batch_size]
            batch_embeddings = self.clip_model.encode_images(batch_paths)
            embeddings.extend(batch_embeddings)
        
        return torch.stack(embeddings)
    
    def __len__(self):
        return len(self.data_samples)
    
    def __getitem__(self, idx):
        sample = self.data_samples[idx]
        image_embedding = self.image_embeddings[idx]
        caption = sample["text"]
        
        # Tokenize caption
        tokens = self.tokenizer.encode(caption, max_length=self.max_length, 
                                     truncation=True, padding='max_length')
        tokens = torch.tensor(tokens, dtype=torch.long)
        
        return {
            'image_embedding': image_embedding,
            'tokens': tokens,
            'caption': caption
        }

Using device: cpu


In [None]:
def train_caption_model(clip_model, data_samples, num_epochs=5, batch_size=8, 
                       learning_rate=2e-5, save_path="clip_caption_model.pt"):
    """
    Train a CLIP captioning model using the provided CLIP model and dataset.
    
    Args:
        clip_model: Pre-trained CLIP model for generating image embeddings
        data_samples: List of data samples with image paths and captions
        num_epochs: Number of training epochs
        batch_size: Training batch size
        learning_rate: Learning rate for optimizer
        save_path: Path to save the trained model
    
    Returns:
        Tuple of (trained_model, training_losses, bleu_scores)
    """
    # Initialize tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    
    # Create dataset and dataloader
    dataset = CaptionTrainingDataset(data_samples, clip_model, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Initialize captioning model
    # Determine CLIP embedding size from the model
    sample_embedding = clip_model.encode_images([data_samples[0]["image_path"]])[0]
    clip_embed_size = sample_embedding.shape[0]
    
    caption_model = ClipCaptionModel(
        prefix_length=10,
        clip_length=10, 
        prefix_size=clip_embed_size
    ).to(device)
    
    # Initialize optimizer and scheduler
    optimizer = AdamW(caption_model.parameters(), lr=learning_rate)
    total_steps = len(dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=total_steps
    )
    
    # Training loop
    training_losses = []
    bleu_scores = []
    
    caption_model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        
        for batch in progress_bar:
            image_embeddings = batch['image_embedding'].to(device)
            tokens = batch['tokens'].to(device)
            
            # Prepare input tokens (exclude last token) and labels (exclude first token)
            input_tokens = tokens[:, :-1]
            labels = tokens[:, 1:]
            
            # Forward pass
            outputs = caption_model(input_tokens, image_embeddings, labels=labels)
            loss = outputs.loss
            
            # Backward pass
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            epoch_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})
        
        avg_epoch_loss = epoch_loss / len(dataloader)
        training_losses.append(avg_epoch_loss)
        
        # Evaluate with BLEU score every epoch
        bleu_score = evaluate_captioning_model(caption_model, clip_model, data_samples[:50], tokenizer)
        bleu_scores.append(bleu_score)
        
        print(f"Epoch {epoch+1}: Loss = {avg_epoch_loss:.4f}, BLEU = {bleu_score:.4f}")
    
    # Save the trained model
    torch.save(caption_model.state_dict(), save_path)
    print(f"Model saved to {save_path}")
    
    return caption_model, training_losses, bleu_scores

In [None]:
def evaluate_captioning_model(caption_model, clip_model, data_samples, tokenizer, max_samples=50):
    """
    Evaluate the captioning model using BLEU score.
    
    Args:
        caption_model: Trained captioning model
        clip_model: CLIP model for generating image embeddings
        data_samples: List of data samples for evaluation
        tokenizer: GPT-2 tokenizer
        max_samples: Maximum number of samples to evaluate
    
    Returns:
        Average BLEU score
    """
    caption_model.eval()
    
    # Limit samples for faster evaluation
    eval_samples = data_samples[:max_samples]
    
    predictions = []
    references = []
    
    with torch.no_grad():
        for sample in tqdm(eval_samples, desc="Evaluating"):
            # Generate image embedding
            image_embedding = clip_model.encode_images([sample["image_path"]])[0].unsqueeze(0).to(device)
            
            # Generate caption
            generated_caption = generate_caption(caption_model, image_embedding, tokenizer)
            
            predictions.append(generated_caption)
            references.append(sample["text"])
    
    # Calculate BLEU score using the function defined earlier
    bleu_score_result = bleu_score(predictions, references)
    return bleu_score_result

def generate_caption(caption_model, image_embedding, tokenizer, max_length=50, temperature=0.8):
    """
    Generate a caption for a given image embedding.
    
    Args:
        caption_model: Trained captioning model
        image_embedding: Image embedding from CLIP model
        tokenizer: GPT-2 tokenizer
        max_length: Maximum caption length
        temperature: Sampling temperature
    
    Returns:
        Generated caption as string
    """
    caption_model.eval()
    
    with torch.no_grad():
        # Start with just the image embedding
        generated_ids = []
        
        # Get dummy tokens for prefix
        dummy_tokens = caption_model.get_dummy_token(1, device)
        
        for _ in range(max_length):
            # Prepare input tokens
            if len(generated_ids) == 0:
                input_tokens = dummy_tokens
            else:
                input_tokens = torch.tensor([generated_ids], dtype=torch.long, device=device)
            
            # Forward pass
            outputs = caption_model(input_tokens, image_embedding)
            logits = outputs.logits
            
            # Get next token probabilities
            next_token_logits = logits[0, -1, :] / temperature
            next_token_probs = torch.softmax(next_token_logits, dim=-1)
            
            # Sample next token
            next_token = torch.multinomial(next_token_probs, 1).item()
            
            # Check for end token
            if next_token == tokenizer.eos_token_id:
                break
                
            generated_ids.append(next_token)
        
        # Decode the generated caption
        caption = tokenizer.decode(generated_ids, skip_special_tokens=True)
        return caption

In [None]:
def train_caption_model_on_coco(clip_model, data_dir="data", max_samples=1000, 
                                batch_size=8, num_epochs=5, split="train2017"):
    """
    Complete pipeline to train a captioning model on COCO dataset.
    
    Args:
        clip_model: Pre-trained CLIP model
        data_dir: Directory containing COCO dataset
        max_samples: Maximum number of training samples
        batch_size: Training batch size
        num_epochs: Number of training epochs
        split: COCO split to use ("train2017" or "val2017")
    
    Returns:
        Tuple of (trained_model, training_losses, bleu_scores)
    """
    # Load COCO dataset
    print("Loading COCO dataset...")
    data_samples = DatasetLoader.load_coco_dataset(
        data_dir=data_dir,
        split=split,
        max_samples=max_samples
    )
    
    print(f"Loaded {len(data_samples)} samples")
    
    # Split data into train and validation
    train_size = int(0.8 * len(data_samples))
    train_samples = data_samples[:train_size]
    val_samples = data_samples[train_size:]
    
    print(f"Training samples: {len(train_samples)}")
    print(f"Validation samples: {len(val_samples)}")
    
    # Train the model
    trained_model, training_losses, bleu_scores = train_caption_model(
        clip_model=clip_model,
        data_samples=train_samples,
        num_epochs=num_epochs,
        batch_size=batch_size
    )
    
    return trained_model, training_losses, bleu_scores, val_samples

def plot_training_metrics(training_losses, bleu_scores):
    """Plot training loss and BLEU scores over epochs."""
    import matplotlib.pyplot as plt
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    # Plot training loss
    ax1.plot(training_losses, 'b-', label='Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss Over Epochs')
    ax1.legend()
    ax1.grid(True)
    
    # Plot BLEU scores
    ax2.plot(bleu_scores, 'r-', label='BLEU Score')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('BLEU Score')
    ax2.set_title('BLEU Score Over Epochs')
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    plt.show()

def test_caption_generation(trained_model, clip_model, val_samples, num_samples=5):
    """
    Test caption generation on validation samples.
    
    Args:
        trained_model: Trained captioning model
        clip_model: CLIP model
        val_samples: Validation samples
        num_samples: Number of samples to test
    """
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    
    print("Testing caption generation:")
    print("=" * 50)
    
    for i in range(min(num_samples, len(val_samples))):
        sample = val_samples[i]
        
        # Generate image embedding
        image_embedding = clip_model.encode_images([sample["image_path"]])[0].unsqueeze(0).to(device)
        
        # Generate caption
        generated_caption = generate_caption(trained_model, image_embedding, tokenizer)
        
        print(f"Sample {i+1}:")
        print(f"  Ground truth: {sample['text']}")
        print(f"  Generated:    {generated_caption}")
        print("-" * 30)

## Fine-tuning CLIP Captioning Model

Now we'll fine-tune a captioning model using the CLIP model defined above. The fine-tuning process includes:

1. **Dataset Loading**: Load COCO dataset with image-caption pairs
2. **Model Architecture**: Use ClipCaptionModel that takes CLIP image embeddings and generates captions using GPT-2
3. **Training Loop**: Train with cross-entropy loss on next-token prediction
4. **BLEU Score Evaluation**: Evaluate generated captions against ground truth using BLEU score
5. **Visualization**: Plot training metrics and test generation quality

### Key Features:

- **Pre-computed Embeddings**: Image embeddings are computed once and cached to speed up training
- **Batch Processing**: Efficient batching for both training and evaluation
- **BLEU Score Tracking**: Monitor caption quality improvement during training
- **Temperature Sampling**: Controlled text generation with adjustable creativity
- **Train/Val Split**: Proper evaluation on held-out validation set

### Usage:

In [None]:
# Example: Train a captioning model using the CLIP model defined above
print("Starting CLIP Captioning Model Fine-tuning...")
print(f"CLIP Model Type: {type(model)}")

# Train the captioning model
try:
    trained_model, losses, bleu_scores, val_samples = train_caption_model_on_coco(
        clip_model=model,
        data_dir="data",
        max_samples=500,  # Adjust based on your dataset size and computational resources
        batch_size=4,     # Adjust based on GPU memory
        num_epochs=3,     # Start with fewer epochs for testing
        split="train2017"
    )
    
    print("\nTraining completed successfully!")
    
    # Plot training metrics
    print("\nPlotting training metrics...")
    plot_training_metrics(losses, bleu_scores)
    
    # Test caption generation on validation samples
    print("\nTesting caption generation on validation samples...")
    test_caption_generation(trained_model, model, val_samples, num_samples=5)
    
    # Final evaluation on validation set
    print("\nFinal evaluation on validation set...")
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    final_bleu = evaluate_captioning_model(trained_model, model, val_samples, tokenizer, max_samples=50)
    print(f"Final Validation BLEU Score: {final_bleu:.4f}")
    
except FileNotFoundError as e:
    print(f"Dataset not found: {e}")
    print("Please ensure COCO dataset is downloaded to the 'data' directory.")
    print("You can use the download script: bash data/download_coco.sh")
except Exception as e:
    print(f"Training failed with error: {e}")
    print("This might be due to insufficient GPU memory. Try reducing batch_size or max_samples.")

2025-10-20 07:53:16,432 - INFO - Loading COCO train2017 dataset...


Starting CLIP Captioning Model Fine-tuning...
CLIP Model Type: <class 'models.clipModel.CLIPModel'>
Loading COCO dataset...


2025-10-20 07:53:16,926 - INFO - Loaded 500 COCO samples


Loaded 500 samples
Training samples: 400
Validation samples: 100
Pre-computing image embeddings...


100%|██████████| 13/13 [01:32<00:00,  7.08s/it]


Training failed with error: Can't load the model for 'gpt2'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'gpt2' is the correct path to a directory containing a file named pytorch_model.bin, tf_model.h5, model.ckpt or flax_model.msgpack.
This might be due to insufficient GPU memory. Try reducing batch_size or max_samples.


In [None]:
# Additional utility functions for comprehensive evaluation

def evaluate_caption_metrics(predictions, references):
    """
    Compute comprehensive caption evaluation metrics.
    
    Args:
        predictions: List of generated captions
        references: List of ground truth captions
    
    Returns:
        Dictionary with various metrics
    """
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    from collections import Counter
    import re
    
    # BLEU scores with different n-grams
    bleu_1_scores = []
    bleu_2_scores = []
    bleu_3_scores = []
    bleu_4_scores = []
    
    # Other metrics
    exact_matches = 0
    token_overlaps = []
    
    smoothing = SmoothingFunction().method1
    
    for pred, ref in zip(predictions, references):
        pred_tokens = pred.lower().split()
        ref_tokens = ref.lower().split()
        
        # BLEU scores
        bleu_1 = sentence_bleu([ref_tokens], pred_tokens, weights=(1, 0, 0, 0), smoothing_function=smoothing)
        bleu_2 = sentence_bleu([ref_tokens], pred_tokens, weights=(0.5, 0.5, 0, 0), smoothing_function=smoothing)
        bleu_3 = sentence_bleu([ref_tokens], pred_tokens, weights=(0.33, 0.33, 0.33, 0), smoothing_function=smoothing)
        bleu_4 = sentence_bleu([ref_tokens], pred_tokens, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothing)
        
        bleu_1_scores.append(bleu_1)
        bleu_2_scores.append(bleu_2)
        bleu_3_scores.append(bleu_3)
        bleu_4_scores.append(bleu_4)
        
        # Exact match
        if pred.lower().strip() == ref.lower().strip():
            exact_matches += 1
        
        # Token overlap (Jaccard similarity)
        pred_set = set(pred_tokens)
        ref_set = set(ref_tokens)
        if len(pred_set.union(ref_set)) > 0:
            overlap = len(pred_set.intersection(ref_set)) / len(pred_set.union(ref_set))
            token_overlaps.append(overlap)
    
    return {
        'bleu_1': np.mean(bleu_1_scores),
        'bleu_2': np.mean(bleu_2_scores),
        'bleu_3': np.mean(bleu_3_scores),
        'bleu_4': np.mean(bleu_4_scores),
        'exact_match_rate': exact_matches / len(predictions),
        'token_overlap': np.mean(token_overlaps),
        'num_samples': len(predictions)
    }

def save_training_results(trained_model, losses, bleu_scores, val_samples, 
                         model_path="clip_caption_model.pt", results_path="training_results.json"):
    """
    Save training results and model for later use.
    
    Args:
        trained_model: Trained captioning model
        losses: Training losses
        bleu_scores: BLEU scores during training
        val_samples: Validation samples
        model_path: Path to save model
        results_path: Path to save training results
    """
    # Save model
    torch.save(trained_model.state_dict(), model_path)
    
    # Save training metrics
    results = {
        'training_losses': losses,
        'bleu_scores': bleu_scores,
        'final_bleu': bleu_scores[-1] if bleu_scores else 0.0,
        'num_epochs': len(losses),
        'validation_samples': len(val_samples)
    }
    
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)
    
    print(f"Model saved to: {model_path}")
    print(f"Results saved to: {results_path}")

def load_trained_model(model_path, clip_embed_size):
    """
    Load a previously trained captioning model.
    
    Args:
        model_path: Path to saved model
        clip_embed_size: Size of CLIP embeddings
    
    Returns:
        Loaded captioning model
    """
    model = ClipCaptionModel(
        prefix_length=10,
        clip_length=10,
        prefix_size=clip_embed_size
    ).to(device)
    
    model.load_state_dict(torch.load(model_path, map_location=device))
    return model

# Function to run comprehensive evaluation
def run_comprehensive_evaluation(trained_model, clip_model, val_samples, max_samples=100):
    """
    Run a comprehensive evaluation of the trained captioning model.
    
    Args:
        trained_model: Trained captioning model
        clip_model: CLIP model for embeddings
        val_samples: Validation samples
        max_samples: Maximum samples to evaluate
    
    Returns:
        Dictionary with comprehensive metrics
    """
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token
    
    print(f"Running comprehensive evaluation on {min(max_samples, len(val_samples))} samples...")
    
    eval_samples = val_samples[:max_samples]
    predictions = []
    references = []
    
    trained_model.eval()
    with torch.no_grad():
        for sample in tqdm(eval_samples, desc="Generating captions"):
            image_embedding = clip_model.encode_images([sample["image_path"]])[0].unsqueeze(0).to(device)
            generated_caption = generate_caption(trained_model, image_embedding, tokenizer)
            
            predictions.append(generated_caption)
            references.append(sample["text"])
    
    # Compute comprehensive metrics
    metrics = evaluate_caption_metrics(predictions, references)
    
    print("\nComprehensive Evaluation Results:")
    print("=" * 40)
    print(f"BLEU-1: {metrics['bleu_1']:.4f}")
    print(f"BLEU-2: {metrics['bleu_2']:.4f}")
    print(f"BLEU-3: {metrics['bleu_3']:.4f}")
    print(f"BLEU-4: {metrics['bleu_4']:.4f}")
    print(f"Exact Match Rate: {metrics['exact_match_rate']:.4f}")
    print(f"Token Overlap: {metrics['token_overlap']:.4f}")
    print(f"Number of Samples: {metrics['num_samples']}")
    
    return metrics, predictions, references