# MaterialVision Model Loading Demo

This notebook demonstrates how to load and use the different vision-language models available in the MaterialVision project:

- **CLIPP-SciBERT**: CLIPP model with SciBERT text encoder
- **CLIPP-DistilBERT**: CLIPP model with DistilBERT text encoder  
- **MobileCLIP**: Apple's MobileCLIP model
- **BLIP**: Salesforce's BLIP model for image-text retrieval

Each model has its own loading function, which loads the checkpoint, places the model on the requested device, and returns a `(model, tokenizer_or_processor, dataset)` tuple so that all four models can be handled the same way.

## 1. Import Required Libraries

First, let's import all the necessary libraries and modules.

In [1]:
import sys
import os
from pathlib import Path
import importlib.util
import torch
import numpy as np
from PIL import Image
import warnings

# Since we're already in the webapp directory, we can import models.py directly
# No need to add paths since models.py is in the same directory

# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore')

Using device: cuda


## 2. Load Functions from External Files

Now let's import the model loading functions from the `models.py` file.

In [2]:
try:
    # Import model loading functions from models.py
    from models import (
        load_clipp_scibert,
        load_clipp_distilbert, 
        load_mobileclip,
        load_blip
    )
    
    print("✅ Successfully imported model loading functions:")
    print("  - load_clipp_scibert")
    print("  - load_clipp_distilbert")
    print("  - load_mobileclip") 
    print("  - load_blip")
    
except ImportError as e:
    print(f"❌ Error importing model functions: {e}")
    print("Make sure you're running this notebook from the webapp directory")
    print("and that models.py exists alongside it.")

Adding to path: /home/jipengsun/MaterialVision/models/CLIPP_allenai
✅ Successfully imported CLIPP SciBERT
Adding to path: /home/jipengsun/MaterialVision/models/CLIPP_bert
✅ Successfully imported CLIPP DistilBERT
Adding to path: /home/jipengsun/MaterialVision/models/Apple_MobileCLIP
✅ Successfully imported MobileCLIP
Adding to path: /home/jipengsun/MaterialVision/models/Salesforce
✅ Successfully imported BLIP
✅ Successfully imported model loading functions:
  - load_clipp_scibert
  - load_clipp_distilbert
  - load_mobileclip
  - load_blip
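
The import above works only when `models.py` is importable from the current working directory. If the notebook is launched from somewhere else, one option is to load the module by file path with `importlib.util`, which the first cell already imports. The following is a minimal sketch: `load_module_from_path` is a hypothetical helper, not part of the MaterialVision code, and the demo writes a stand-in file instead of touching the real `models.py`.

```python
import importlib.util
import sys
import tempfile
from pathlib import Path


def load_module_from_path(module_name, file_path):
    """Import a Python file by explicit path (hypothetical helper)."""
    file_path = Path(file_path)
    if not file_path.is_file():
        raise FileNotFoundError(f"Module file not found: {file_path}")
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {file_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so code inside the module can find itself.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


# Demo with a stand-in file; in the notebook the path would point at models.py.
with tempfile.TemporaryDirectory() as tmp:
    stub = Path(tmp) / "models_stub.py"
    stub.write_text("def load_blip(checkpoint_path, device):\n    return 'stub'\n")
    stub_module = load_module_from_path("models_stub", stub)
    stub_result = stub_module.load_blip("x.pth", "cpu")

print(stub_result)
```

Loading by path avoids editing `sys.path`, at the cost of having to know where the file lives.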


## 3. Call Loaded Functions with Sample Data

Let's check for available checkpoints and demonstrate loading each model.

In [3]:
# Define checkpoint paths (relative to webapp directory, go up one level to access models)
checkpoint_paths = {
    'clipp_scibert': '../models/CLIPP_allenai/checkpoints/best_clipp.pth',
    'clipp_distilbert': '../models/CLIPP_bert/checkpoints/best_clipp_bert.pth', 
    'mobileclip': '../models/Apple_MobileCLIP/checkpoints/best_clipp_apple.pth',
    'blip': '../models/Salesforce/checkpoints_blip/best_blip.pth'
}

# Check which checkpoints exist
available_models = {}
for model_name, path in checkpoint_paths.items():
    full_path = Path(path)
    if full_path.exists():
        available_models[model_name] = str(full_path)
        print(f"✅ {model_name}: {path}")
    else:
        print(f"❌ {model_name}: {path} (not found)")

print(f"\nFound {len(available_models)} available model checkpoints.")

✅ clipp_scibert: ../models/CLIPP_allenai/checkpoints/best_clipp.pth
✅ clipp_distilbert: ../models/CLIPP_bert/checkpoints/best_clipp_bert.pth
✅ mobileclip: ../models/Apple_MobileCLIP/checkpoints/best_clipp_apple.pth
✅ blip: ../models/Salesforce/checkpoints_blip/best_blip.pth

Found 4 available model checkpoints.
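
The relative paths above resolve against the current working directory, so the existence check depends on where the notebook kernel was started. A minimal sketch of a lookup that anchors the paths to an explicit base directory follows. `find_checkpoints` and its `base_dir` argument are hypothetical names, and the demo builds a temporary directory tree rather than using the real checkpoints.

```python
import tempfile
from pathlib import Path


def find_checkpoints(checkpoint_paths, base_dir):
    """Split checkpoint paths into found and missing, resolved against base_dir."""
    base_dir = Path(base_dir)
    found, missing = {}, {}
    for name, rel_path in checkpoint_paths.items():
        full_path = (base_dir / rel_path).resolve()
        if full_path.is_file():
            found[name] = str(full_path)
        else:
            missing[name] = str(full_path)
    return found, missing


with tempfile.TemporaryDirectory() as tmp:
    # Stand-in layout: <tmp>/webapp next to <tmp>/models, with one checkpoint present.
    webapp = Path(tmp) / "webapp"
    ckpt = Path(tmp) / "models" / "Salesforce" / "checkpoints_blip" / "best_blip.pth"
    webapp.mkdir()
    ckpt.parent.mkdir(parents=True)
    ckpt.touch()

    demo_paths = {
        "blip": "../models/Salesforce/checkpoints_blip/best_blip.pth",
        "mobileclip": "../models/Apple_MobileCLIP/checkpoints/best_clipp_apple.pth",
    }
    found, missing = find_checkpoints(demo_paths, base_dir=webapp)

print(sorted(found), sorted(missing))
```

Returning the missing entries alongside the found ones makes it easy to report exactly which absolute path was tried.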


### 3.1 Load CLIPP-SciBERT Model

In [4]:
if 'clipp_scibert' in available_models:
    try:
        print("Loading CLIPP-SciBERT model...")
        clipp_scibert_model, clipp_scibert_tokenizer, clipp_scibert_dataset = load_clipp_scibert(
            checkpoint_path=available_models['clipp_scibert'],
            device=str(device)
        )
        
        print("✅ CLIPP-SciBERT model loaded successfully!")
        print(f"   Model device: {next(clipp_scibert_model.parameters()).device}")
        print(f"   Tokenizer type: {type(clipp_scibert_tokenizer).__name__}")
        print(f"   Dataset type: {type(clipp_scibert_dataset).__name__}")

        # Test tokenization
        sample_text = "The chemical formula is UGe2Pt2. The mbj_bandgap value is 0.0."
        caption, input_ids, attention_mask = clipp_scibert_dataset.prepare_caption(sample_text)
        print(f"sample input_ids: {input_ids}")
        print(f"sample attention_mask: {attention_mask}")

        # Test text embedding
        txt_emb = clipp_scibert_model.get_text_features(input_ids.view(1,-1).to(device), attention_mask.view(1,-1).to(device))
        print(f"Text embedding shape: {txt_emb.shape}")
    except Exception as e:
        print(f"❌ Error loading CLIPP-SciBERT: {e}")
else:
    print("⏭️  CLIPP-SciBERT checkpoint not available, skipping...")

Loading CLIPP-SciBERT model...


2025-11-09 23:18:23,539 INFO: Loading pretrained weights from Hugging Face hub (timm/vit_base_patch16_224.augreg2_in21k_ft_in1k)
2025-11-09 23:18:23,581 INFO: [timm/vit_base_patch16_224.augreg2_in21k_ft_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


✅ CLIPP-SciBERT model loaded successfully!
   Model device: cuda:0
   Tokenizer type: BertTokenizerFast
   Dataset type: ImageTextDataset
sample input_ids: tensor([ 102,  158,  504,  170, 1240,  170, 3471,  244,  205,  244,  103,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0

### 3.2 Load CLIPP-DistilBERT Model

In [5]:
if 'clipp_distilbert' in available_models:
    try:
        print("Loading CLIPP-DistilBERT model...")
        clipp_distilbert_model, clipp_distilbert_tokenizer, clipp_distilbert_dataset = load_clipp_distilbert(
            checkpoint_path=available_models['clipp_distilbert'],
            device=str(device)
        )
        
        print("✅ CLIPP-DistilBERT model loaded successfully!")
        print(f"   Model device: {next(clipp_distilbert_model.parameters()).device}")
        print(f"   Tokenizer type: {type(clipp_distilbert_tokenizer).__name__}")
        print(f"   Dataset type: {type(clipp_distilbert_dataset).__name__}")

        # Test tokenization
        sample_text = "The chemical formula is UGe2Pt2. The mbj_bandgap value is 0.0."
        caption, input_ids, attention_mask = clipp_distilbert_dataset.prepare_caption(sample_text)
        embeddings = clipp_distilbert_model.get_text_features(input_ids.view(1,-1).to(device), attention_mask.view(1,-1).to(device))
        print(f"Text embedding shape: {embeddings.shape}")
        
    except Exception as e:
        print(f"❌ Error loading CLIPP-DistilBERT: {e}")
else:
    print("⏭️  CLIPP-DistilBERT checkpoint not available, skipping...")

Loading CLIPP-DistilBERT model...


2025-11-09 23:18:40,711 INFO: Loading pretrained weights from Hugging Face hub (timm/resnet50.a1_in1k)
2025-11-09 23:18:40,756 INFO: [timm/resnet50.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


✅ CLIPP-DistilBERT model loaded successfully!
   Model device: cuda:0
   Tokenizer type: DistilBertTokenizer
   Dataset type: ImageTextDataset
Text embedding shape: torch.Size([1, 256])


### 3.3 Load MobileCLIP Model

In [6]:
if 'mobileclip' in available_models:
    try:
        print("Loading MobileCLIP model...")
        mobileclip_model, mobileclip_tokenizer, mobileclip_dataset = load_mobileclip(
            checkpoint_path=available_models['mobileclip'],
            device=str(device)
        )
        
        print("✅ MobileCLIP model loaded successfully!")
        print(f"   Model device: {next(mobileclip_model.parameters()).device}")
        print(f"   Tokenizer type: {type(mobileclip_tokenizer)}")
        print(f"   Dataset type: {type(mobileclip_dataset).__name__}")

        # Test tokenization (MobileCLIP uses different tokenization)
        sample_text = "The chemical formula is UGe2Pt2. The mbj_bandgap value is 0.0."
        caption, text_tokens = mobileclip_dataset.prepare_caption(sample_text)
        embeddings = mobileclip_model.get_text_features(text_tokens.to(device))
        print(f"Text embedding shape: {embeddings.shape}")

    except Exception as e:
        print(f"❌ Error loading MobileCLIP: {e}")
else:
    print("⏭️  MobileCLIP checkpoint not available, skipping...")

2025-11-09 23:18:49,610 INFO: Loaded MobileCLIP-S2 model config.


Loading MobileCLIP model...


2025-11-09 23:18:51,001 INFO: Loading pretrained MobileCLIP-S2 weights (datacompdr).


✅ MobileCLIP model loaded successfully!
   Model device: cuda:0
   Tokenizer type: <class 'open_clip.tokenizer.SimpleTokenizer'>
   Dataset type: ImageTextDataset
Text embedding shape: torch.Size([1, 256])


### 3.4 Load BLIP Model

In [12]:
if 'blip' in available_models:
    try:
        print("Loading BLIP model...")
        blip_model, blip_processor, blip_dataset = load_blip(
            checkpoint_path=available_models['blip'],
            device=str(device)
        )
        
        print("✅ BLIP model loaded successfully!")
        print(f"   Model device: {next(blip_model.parameters()).device}")
        print(f"   Processor type: {type(blip_processor).__name__}")
        print(f"   Dataset type: {type(blip_dataset).__name__}")
        
        # Test text processing
        sample_text = "The chemical formula is UGe2Pt2. The mbj_bandgap value is 0.0."
        caption, input_ids, attention_mask = blip_dataset.prepare_caption(sample_text)
        embeddings = blip_model.get_text_features(input_ids=input_ids.to(device), attention_mask=attention_mask.to(device))
        print(f"Text embedding shape: {embeddings.shape}")
        
    except Exception as e:
        print(f"❌ Error loading BLIP: {e}")
else:
    print("⏭️  BLIP checkpoint not available, skipping...")

Loading BLIP model...
✅ BLIP model loaded successfully!
   Model device: cuda:0
   Processor type: BlipProcessor
   Dataset type: ImageTextDataset
Text embedding shape: torch.Size([17, 256])


In [15]:
caption, attention_mask.shape, embeddings.shape

('1 U 2 Ge 2 Pt 0.0', torch.Size([17, 3]), torch.Size([17, 256]))

## 4. Display Function Results

Let's demonstrate how to use the loaded models for text encoding and feature extraction.

In [None]:
# Sample material science texts
sample_texts = [
    "Silicon dioxide has excellent optical properties",
    "Graphene exhibits high electrical conductivity", 
    "Titanium dioxide is a versatile photocatalyst",
    "Perovskite materials for solar cell applications"
]

print("🔬 Testing text feature extraction with loaded models:\n")

# Test each loaded model
loaded_models = []

# Check CLIPP-SciBERT
if 'clipp_scibert' in available_models and 'clipp_scibert_model' in locals():
    loaded_models.append(('CLIPP-SciBERT', clipp_scibert_model, clipp_scibert_tokenizer))

# Check CLIPP-DistilBERT  
if 'clipp_distilbert' in available_models and 'clipp_distilbert_model' in locals():
    loaded_models.append(('CLIPP-DistilBERT', clipp_distilbert_model, clipp_distilbert_tokenizer))

# Check MobileCLIP
if 'mobileclip' in available_models and 'mobileclip_model' in locals():
    loaded_models.append(('MobileCLIP', mobileclip_model, mobileclip_tokenizer))

# Check BLIP
if 'blip' in available_models and 'blip_model' in locals():
    loaded_models.append(('BLIP', blip_model, blip_processor))

print(f"Testing with {len(loaded_models)} successfully loaded models:")
for name, _, _ in loaded_models:
    print(f"  ✓ {name}")
print()

In [None]:
# Test text feature extraction for each model
for model_name, model, tokenizer_or_processor in loaded_models:
    print(f"📝 Testing {model_name}:")
    
    try:
        with torch.no_grad():
            if model_name == 'BLIP':
                # BLIP uses processor
                processed = tokenizer_or_processor(text=sample_texts, return_tensors='pt', padding=True, truncation=True)
                # Move tensors to device
                for key in processed:
                    if isinstance(processed[key], torch.Tensor):
                        processed[key] = processed[key].to(device)
                features = model.get_text_features(**processed)
            
            elif model_name == 'MobileCLIP':
                # MobileCLIP uses different tokenization
                tokens = tokenizer_or_processor(sample_texts).to(device)
                features = model.get_text_features(tokens)
            
            else:
                # CLIPP models use standard tokenization
                tokens = tokenizer_or_processor(sample_texts, return_tensors='pt', padding=True, truncation=True)
                # Move tensors to device
                for key in tokens:
                    tokens[key] = tokens[key].to(device)
                features = model.get_text_features(tokens['input_ids'], tokens['attention_mask'])
            
            print(f"   ✅ Text features shape: {features.shape}")
            print(f"   📊 Feature statistics:")
            print(f"      Mean: {features.mean().item():.4f}")
            print(f"      Std:  {features.std().item():.4f}")
            print(f"      Min:  {features.min().item():.4f}")
            print(f"      Max:  {features.max().item():.4f}")
            
    except Exception as e:
        print(f"   ❌ Error: {e}")
    
    print()

## 5. Error Handling for Function Calls

Let's demonstrate robust error handling when working with these models.

In [None]:
def safe_model_loading(model_name, load_function, checkpoint_path, device):
    """
    Safely load a model with comprehensive error handling.
    
    Args:
        model_name: Name of the model for logging
        load_function: Function to load the model
        checkpoint_path: Path to model checkpoint
        device: Device to load model on
    
    Returns:
        tuple: (success, model_data, error_message)
    """
    try:
        print(f"🔄 Attempting to load {model_name}...")
        
        # Check if checkpoint exists
        if not Path(checkpoint_path).exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        
        # Check if checkpoint is readable
        if not os.access(checkpoint_path, os.R_OK):
            raise PermissionError(f"Cannot read checkpoint: {checkpoint_path}")
        
        # Attempt to load the model (keyword arguments, matching the calls in section 3)
        model_data = load_function(checkpoint_path=checkpoint_path, device=device)
        
        # Validate model data
        if not model_data or len(model_data) != 3:
            raise ValueError("Invalid model data returned from load function")
        
        model, tokenizer, dataset = model_data
        
        # Basic validation
        if model is None:
            raise ValueError("Model is None")
        
        if tokenizer is None:
            raise ValueError("Tokenizer/Processor is None")
            
        print(f"✅ {model_name} loaded successfully!")
        return True, model_data, None
        
    except FileNotFoundError as e:
        error_msg = f"File not found: {e}"
        print(f"❌ {model_name}: {error_msg}")
        return False, None, error_msg
        
    except PermissionError as e:
        error_msg = f"Permission error: {e}"
        print(f"❌ {model_name}: {error_msg}")
        return False, None, error_msg
        
    except ImportError as e:
        error_msg = f"Import error (missing dependencies): {e}"
        print(f"❌ {model_name}: {error_msg}")
        return False, None, error_msg
        
    except RuntimeError as e:
        error_msg = f"Runtime error (possibly CUDA/memory): {e}"
        print(f"❌ {model_name}: {error_msg}")
        return False, None, error_msg
        
    except Exception as e:
        error_msg = f"Unexpected error: {type(e).__name__}: {e}"
        print(f"❌ {model_name}: {error_msg}")
        return False, None, error_msg

# Test safe loading with a non-existent checkpoint
print("🧪 Testing error handling with invalid checkpoint:")
success, data, error = safe_model_loading(
    "Test Model", 
    load_clipp_scibert, 
    "non_existent_checkpoint.pth", 
    str(device)
)
print(f"   Success: {success}")
print(f"   Error: {error}")
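
The five `except` branches in `safe_model_loading` differ only in their message prefix. One way to shorten the function, sketched here as an alternative rather than as the project's implementation, is a lookup from exception type to prefix. `ERROR_PREFIXES`, `describe_error`, `safe_call`, and `failing_loader` are hypothetical names introduced for illustration.

```python
# Ordered (exception type, message prefix) pairs; the first match wins.
ERROR_PREFIXES = (
    (FileNotFoundError, "File not found"),
    (PermissionError, "Permission error"),
    (ImportError, "Import error (missing dependencies)"),
    (RuntimeError, "Runtime error (possibly CUDA/memory)"),
)


def describe_error(exc):
    """Map an exception to the same message format used in safe_model_loading."""
    for exc_type, prefix in ERROR_PREFIXES:
        if isinstance(exc, exc_type):
            return f"{prefix}: {exc}"
    return f"Unexpected error: {type(exc).__name__}: {exc}"


def safe_call(load_function, *args, **kwargs):
    """Return (success, result, error_message) without raising."""
    try:
        return True, load_function(*args, **kwargs), None
    except Exception as exc:
        return False, None, describe_error(exc)


def failing_loader(checkpoint_path, device):
    # Stand-in for a real loader that cannot find its checkpoint.
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")


ok, result, message = safe_call(
    failing_loader, checkpoint_path="non_existent_checkpoint.pth", device="cpu"
)
print(ok, message)
```

Because `isinstance` also matches subclasses, a `ModuleNotFoundError` is reported under the import prefix, which mirrors how the original `except ImportError` branch behaves.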

## Summary

This notebook demonstrated how to:

1. **Import model loading functions** from the `models.py` file (located in the same webapp directory)
2. **Check for available checkpoints** in the `../models/` directory and handle missing files gracefully
3. **Load each model type** (CLIPP-SciBERT, CLIPP-DistilBERT, MobileCLIP, BLIP) with proper error handling
4. **Extract text features** using the loaded models with sample material science texts
5. **Load validation data** from `../../data/alpaca_mbj_bandgap_test.csv`
6. **Generate text and image embeddings** for all loaded models on the validation set
7. **Save embeddings** to the `./embeddings/` directory in multiple formats (pickle, numpy, text)
8. **Compute retrieval metrics** (text-to-image and image-to-text) for performance evaluation
9. **Implement robust error handling** for model loading and embedding generation operations

### Key Takeaways:

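- Each model exposes `get_text_features()` and `get_image_features()`, although the arguments differ: the CLIPP models and BLIP take `input_ids` and `attention_mask`, while MobileCLIP takes a single token tensor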
- Different models use different tokenization approaches (AutoTokenizer vs open_clip tokenizer vs Processor)
- Embeddings are saved in an organized directory structure under `./embeddings/[model_name]/`
- Retrieval metrics provide quantitative comparison between different models
- Proper error handling is crucial when working with large models and checkpoints
- All models can be used on both CPU and GPU devices
- The notebook is designed to run from the webapp directory with relative paths to the models
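
The retrieval metrics mentioned above are not shown in this part of the notebook. A common choice for paired text-image data is Recall@K: for each text query, check whether its paired image (the one at the same index) appears among the K most similar images. The sketch below assumes that metric and a precomputed similarity matrix. `recall_at_k` is a hypothetical helper, not the project's implementation.

```python
def recall_at_k(similarity, k):
    """Fraction of queries whose paired item (same index) ranks in the top k.

    similarity[i][j] is the score between query i and candidate j.
    """
    if not similarity:
        raise ValueError("similarity matrix is empty")
    hits = 0
    for i, row in enumerate(similarity):
        # Candidate indices sorted by descending score.
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        if i in ranked[:k]:
            hits += 1
    return hits / len(similarity)


# Toy text-to-image similarity matrix for three pairs.
sim = [
    [0.9, 0.1, 0.2],  # query 0: paired image ranked first
    [0.8, 0.7, 0.1],  # query 1: paired image ranked second
    [0.3, 0.5, 0.4],  # query 2: paired image ranked second
]
r1 = recall_at_k(sim, k=1)
r2 = recall_at_k(sim, k=2)
print(r1, r2)
```

Image-to-text recall follows from the same function applied to the transposed matrix.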

### Generated Files:

- `./embeddings/[model_name]/validation_embeddings.pkl` - Complete embeddings data
- `./embeddings/[model_name]/text_embeddings.npy` - Text embeddings as numpy array
- `./embeddings/[model_name]/image_embeddings.npy` - Image embeddings as numpy array  
- `./embeddings/[model_name]/captions_with_ids.txt` - Sample IDs and captions, one `sample_id,caption` pair per line
- `./embeddings/retrieval_metrics.pkl` - Computed retrieval metrics for all models
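
A short sketch of reading one of these pickle files back and checking that the entries still line up. The key names (`text_embeddings`, `image_embeddings`, `captions`, `sample_ids`) follow the dictionary written in section 6. `check_pairing` is a hypothetical helper, and the demo round-trips a tiny stand-in dictionary that uses plain lists instead of NumPy arrays.

```python
import pickle
import tempfile
from pathlib import Path


def check_pairing(data):
    """Return the number of pairs, or raise if the entries are misaligned."""
    keys = ("text_embeddings", "image_embeddings", "captions", "sample_ids")
    lengths = {key: len(data[key]) for key in keys}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"Inconsistent lengths: {lengths}")
    return lengths["sample_ids"]


# Stand-in data with the same keys as the saved file (values are made up).
stand_in = {
    "text_embeddings": [[0.1, 0.2], [0.3, 0.4]],
    "image_embeddings": [[0.5, 0.6], [0.7, 0.8]],
    "captions": ["caption a", "caption b"],
    "sample_ids": [0, 1],
    "model_name": "demo",
}

with tempfile.TemporaryDirectory() as tmp:
    pickle_path = Path(tmp) / "validation_embeddings.pkl"
    with open(pickle_path, "wb") as f:
        pickle.dump(stand_in, f)
    # Only unpickle files you produced yourself; pickle can execute code on load.
    with open(pickle_path, "rb") as f:
        loaded = pickle.load(f)

n_pairs = check_pairing(loaded)
print(n_pairs)
```

`len()` returns the first dimension of a NumPy array as well, so the same check applies to the real files.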

### Next Steps:

- Use embeddings for similarity search between text and image queries
- Implement embedding-based material property prediction
- Compare embedding quality across different models
- Fine-tune models on domain-specific material science data
- Build web applications using the saved embeddings
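
As a starting point for the first item, similarity search over saved embeddings reduces to ranking candidates by cosine similarity. The sketch is plain Python for clarity; on real embedding matrices a vectorized NumPy or PyTorch version would be the practical choice. `cosine_similarity` and `top_k` are hypothetical helpers, and the vectors are toy values.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(query, candidates, k=3):
    """Return the k best (index, score) pairs, highest score first."""
    scores = [(idx, cosine_similarity(query, vec)) for idx, vec in enumerate(candidates)]
    scores.sort(key=lambda pair: pair[1], reverse=True)
    return scores[:k]


toy_image_embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
toy_text_query = [2.0, 0.1]
ranking = top_k(toy_text_query, toy_image_embeddings, k=2)
print([idx for idx, _ in ranking])
```

Cosine similarity ignores vector length, so a query of any magnitude ranks the candidates the same way.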

## 6. Load Validation Data and Generate Embeddings

Now let's load the validation data and use each model to generate text and image embeddings, saving them to the embeddings folder.

In [9]:
import pandas as pd
import os
from torch.utils.data import DataLoader
from pathlib import Path
import pickle
import numpy as np

# Define validation data path
VAL_CSV = Path('../../data/alpaca_mbj_bandgap_test.csv')
BATCH_SIZE = 32

# Create embeddings directory
embeddings_dir = Path('./embeddings')
embeddings_dir.mkdir(exist_ok=True)

print(f"Validation CSV path: {VAL_CSV}")
print(f"Validation CSV exists: {VAL_CSV.exists()}")
print(f"Embeddings directory: {embeddings_dir}")

# Load validation data
if VAL_CSV.exists():
    val_df = pd.read_csv(VAL_CSV)
    print(f"✅ Loaded validation data with {len(val_df)} samples")
    print(f"   Columns: {list(val_df.columns)}")
    print(f"   Sample columns preview:")
    for col in val_df.columns[:5]:  # Show first 5 columns
        value = str(val_df[col].iloc[0])
        preview = value if len(value) < 50 else value[:50] + '...'
        print(f"     {col}: {preview}")
else:
    print(f"❌ Validation CSV not found at {VAL_CSV}")
    val_df = None

Validation CSV path: ../../data/alpaca_mbj_bandgap_test.csv
Validation CSV exists: False
Embeddings directory: embeddings
❌ Validation CSV not found at ../../data/alpaca_mbj_bandgap_test.csv
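
In the recorded run the CSV was not found under `../../data/`, while the checkpoints resolved under `../models/`. This notebook does not settle where the data directory actually lives, so the sketch below only shows how to try several candidate locations and keep the first that exists. `first_existing` is a hypothetical helper, and the demo uses a temporary directory with a placeholder file.

```python
import tempfile
from pathlib import Path


def first_existing(candidates):
    """Return the first candidate path that is an existing file, else None."""
    for candidate in candidates:
        path = Path(candidate)
        if path.is_file():
            return path
    return None


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "data").mkdir()
    csv_file = root / "data" / "alpaca_mbj_bandgap_test.csv"
    csv_file.write_text("placeholder\n")  # stand-in content, not real data

    candidates = [
        root / "missing" / "alpaca_mbj_bandgap_test.csv",  # does not exist
        csv_file,                                          # exists
    ]
    resolved = first_existing(candidates)
    resolved_name = resolved.name if resolved else None
    nothing = first_existing([root / "nope.csv"])

print(resolved_name, nothing)
```

In the notebook, the candidate list would hold the locations you consider plausible for your checkout, and `val_df` would stay `None` when the helper returns `None`.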


In [None]:
# Function to generate unified embeddings for each model
def generate_embeddings_for_model(model_name, model, tokenizer_or_processor, dataset_class, val_df, device):
    """
    Generate unified embeddings that combine both text and image information into single embeddings.
    
    Args:
        model_name: Name of the model (for saving files)
        model: The loaded model
        tokenizer_or_processor: Tokenizer or processor for the model
        dataset_class: Dataset class for creating data loader
        val_df: Validation dataframe
        device: Device to run on
    
    Returns:
        tuple: (unified_embeddings, sample_ids, captions, text_embeddings, image_embeddings)
    """
    print(f"\n🔄 Generating unified embeddings for {model_name}...")
    
    try:
        # Create dataset and dataloader based on model type
        if model_name == 'MobileCLIP':
            # For MobileCLIP, need to import the specific preprocessing
            sys.path.append('../models/Apple_MobileCLIP')
            import open_clip
            
            # Load MobileCLIP-S2 preprocessor  
            _, _, preprocess_s2 = open_clip.create_model_and_transforms('MobileCLIP-S2', pretrained='datacompdr')
            
            val_dataset = dataset_class(val_df, tokenizer_or_processor, preprocess_s2, train=False)
            
        elif model_name == 'BLIP':
            # BLIP uses a processor for both text and images
            val_dataset = dataset_class(val_df, tokenizer_or_processor, train=False)
            
        else:
            # CLIPP models use standard tokenizer + transform
            # We'll need to define a transform for images
            from torchvision import transforms
            
            # Standard image preprocessing for CLIPP models
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            
            val_dataset = dataset_class(val_df, tokenizer_or_processor, transform, train=False)
        
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
        
        # Generate embeddings with proper tracking
        text_embeddings_list = []
        image_embeddings_list = []
        captions = []
        sample_ids = []  # Track unique IDs for each sample
        
        model.eval()
        with torch.no_grad():
            global_sample_idx = 0  # Global counter for unique IDs
            
            for i, batch in enumerate(val_loader):
                print(f"   Processing batch {i+1}/{len(val_loader)}", end='\r')
                
                if model_name == 'MobileCLIP':
                    # MobileCLIP specific processing
                    images = batch['image'].to(device)
                    text_tokens = batch['text_tokens'].to(device)
                    texts = batch['caption']
                    
                    # Get embeddings using the model's forward method
                    img_emb, txt_emb = model(images, text_tokens)
                    
                elif model_name == 'BLIP':
                    # BLIP specific processing
                    images = batch['image'].to(device)
                    texts = batch['caption']
                    
                    # Process text
                    text_inputs = tokenizer_or_processor(text=texts, return_tensors='pt', padding=True, truncation=True)
                    for key in text_inputs:
                        if isinstance(text_inputs[key], torch.Tensor):
                            text_inputs[key] = text_inputs[key].to(device)
                    
                    # Get embeddings
                    img_emb = model.get_image_features(images)
                    txt_emb = model.get_text_features(**text_inputs)
                    
                else:
                    # CLIPP models
                    images = batch['image'].to(device)
                    texts = batch['caption']
                    
                    # Tokenize text
                    text_inputs = tokenizer_or_processor(texts, return_tensors='pt', padding=True, truncation=True)
                    for key in text_inputs:
                        text_inputs[key] = text_inputs[key].to(device)
                    
                    # Get embeddings
                    img_emb = model.get_image_features(images)
                    txt_emb = model.get_text_features(text_inputs['input_ids'], text_inputs['attention_mask'])
                
                # Collect embeddings and metadata with proper ordering
                batch_size = len(texts)
                batch_ids = list(range(global_sample_idx, global_sample_idx + batch_size))
                
                text_embeddings_list.append(txt_emb.cpu())
                image_embeddings_list.append(img_emb.cpu())
                captions.extend(texts)
                sample_ids.extend(batch_ids)
                
                global_sample_idx += batch_size
        
        # Concatenate all embeddings
        text_embeddings = torch.cat(text_embeddings_list, dim=0)
        image_embeddings = torch.cat(image_embeddings_list, dim=0)
        
        # Create unified embeddings by combining text and image embeddings
        print(f"\n   🔄 Creating unified embeddings...")
        
        # Method 1: Simple concatenation
        unified_embeddings_concat = torch.cat([text_embeddings, image_embeddings], dim=1)
        
        # Method 2: Weighted average (you can adjust weights)
        text_weight = 0.5
        image_weight = 0.5
        
        # The averaged and element-wise variants need matching dimensions.
        # If the two differ, truncate both to the smaller one. This is plain
        # slicing, not a learned projection, so trailing features are discarded.
        target_dim = min(text_embeddings.shape[1], image_embeddings.shape[1])
        text_proj = text_embeddings[:, :target_dim]
        image_proj = image_embeddings[:, :target_dim]
        
        unified_embeddings_avg = text_weight * text_proj + image_weight * image_proj
        
        # Method 3: Element-wise operations
        unified_embeddings_multiply = text_proj * image_proj  # Element-wise multiplication
        unified_embeddings_max = torch.max(text_proj, image_proj)  # Element-wise maximum
        
        print(f"   ✅ Generated unified embeddings:")
        print(f"      Text embeddings: {text_embeddings.shape}")
        print(f"      Image embeddings: {image_embeddings.shape}")
        print(f"      Unified (concat): {unified_embeddings_concat.shape}")
        print(f"      Unified (avg): {unified_embeddings_avg.shape}")
        print(f"      Unified (multiply): {unified_embeddings_multiply.shape}")
        print(f"      Unified (max): {unified_embeddings_max.shape}")
        print(f"      Number of captions: {len(captions)}")
        print(f"      Number of sample IDs: {len(sample_ids)}")
        print(f"      Sample ID range: {min(sample_ids)} - {max(sample_ids)}")
        
        # Use concatenation as the default unified embedding (most comprehensive)
        unified_embeddings = unified_embeddings_concat
        
        # Verify pairing consistency
        assert len(captions) == len(sample_ids) == unified_embeddings.shape[0] == text_embeddings.shape[0] == image_embeddings.shape[0], \
            f"Inconsistent lengths: captions={len(captions)}, ids={len(sample_ids)}, " \
            f"unified_emb={unified_embeddings.shape[0]}, text_emb={text_embeddings.shape[0]}, img_emb={image_embeddings.shape[0]}"
        
        # Save embeddings with proper pairing information
        model_embeddings_dir = embeddings_dir / model_name.lower().replace('-', '_')
        model_embeddings_dir.mkdir(exist_ok=True)
        
        # Create comprehensive embeddings data with unified embeddings
        embeddings_data = {
            'unified_embeddings': unified_embeddings.numpy(),  # Main unified embedding
            'unified_embeddings_concat': unified_embeddings_concat.numpy(),
            'unified_embeddings_avg': unified_embeddings_avg.numpy(), 
            'unified_embeddings_multiply': unified_embeddings_multiply.numpy(),
            'unified_embeddings_max': unified_embeddings_max.numpy(),
            'text_embeddings': text_embeddings.numpy(),  # Keep originals for analysis
            'image_embeddings': image_embeddings.numpy(),
            'captions': captions,
            'sample_ids': sample_ids,
            'model_name': model_name,
            'fusion_info': {
                'default_method': 'concatenation',
                'text_weight': text_weight,
                'image_weight': image_weight,
                'original_text_dim': text_embeddings.shape[1],
                'original_image_dim': image_embeddings.shape[1],
                'unified_dim': unified_embeddings.shape[1]
            },
            'pairing_info': {
                'description': 'unified_embeddings[i] combines text and image info for captions[i] and sample_ids[i]',
                'total_pairs': len(sample_ids),
                'embedding_dim_unified': unified_embeddings.shape[1],
                'embedding_dim_text': text_embeddings.shape[1],
                'embedding_dim_image': image_embeddings.shape[1]
            }
        }
        
        # Save as pickle with all metadata
        pickle_path = model_embeddings_dir / 'validation_embeddings.pkl'
        with open(pickle_path, 'wb') as f:
            pickle.dump(embeddings_data, f)
        
        # Save different versions of unified embeddings
        np.save(model_embeddings_dir / 'unified_embeddings.npy', unified_embeddings.numpy())
        np.save(model_embeddings_dir / 'unified_embeddings_concat.npy', unified_embeddings_concat.numpy())
        np.save(model_embeddings_dir / 'unified_embeddings_avg.npy', unified_embeddings_avg.numpy())
        np.save(model_embeddings_dir / 'unified_embeddings_multiply.npy', unified_embeddings_multiply.numpy())
        np.save(model_embeddings_dir / 'unified_embeddings_max.npy', unified_embeddings_max.numpy())
        
        # Save original embeddings for comparison
        np.save(model_embeddings_dir / 'text_embeddings.npy', text_embeddings.numpy())
        np.save(model_embeddings_dir / 'image_embeddings.npy', image_embeddings.numpy())
        np.save(model_embeddings_dir / 'sample_ids.npy', np.array(sample_ids))
        
        # Save captions and IDs as structured text file
        with open(model_embeddings_dir / 'captions_with_ids.txt', 'w') as f:
            f.write("# Format: sample_id,caption\n")
            for sample_id, caption in zip(sample_ids, captions):
                # Escape commas in captions
                escaped_caption = caption.replace(',', '\\,')
                f.write(f"{sample_id},{escaped_caption}\n")
        
        # Save comprehensive information
        with open(model_embeddings_dir / 'embedding_info.txt', 'w') as f:
            f.write(f"Model: {model_name}\n")
            f.write(f"Total samples: {len(sample_ids)}\n")
            f.write(f"Sample ID range: {min(sample_ids)} - {max(sample_ids)}\n\n")
            
            f.write(f"UNIFIED EMBEDDINGS:\n")
            f.write(f"  Default method: concatenation\n")
            f.write(f"  Unified dimension: {unified_embeddings.shape[1]}\n")
            f.write(f"  Available fusion methods:\n")
            f.write(f"    - Concatenation: {unified_embeddings_concat.shape[1]}D\n")
            f.write(f"    - Weighted average: {unified_embeddings_avg.shape[1]}D (text:{text_weight}, image:{image_weight})\n")
            f.write(f"    - Element-wise multiply: {unified_embeddings_multiply.shape[1]}D\n")
            f.write(f"    - Element-wise max: {unified_embeddings_max.shape[1]}D\n\n")
            
            f.write(f"ORIGINAL EMBEDDINGS:\n")
            f.write(f"  Text embedding dimension: {text_embeddings.shape[1]}\n")
            f.write(f"  Image embedding dimension: {image_embeddings.shape[1]}\n\n")
            
            f.write(f"PAIRING RULE:\n")
            f.write(f"  unified_embeddings[i] combines info from text_embeddings[i] + image_embeddings[i]\n")
            f.write(f"  for captions[i] and sample_ids[i]\n\n")
            
            f.write(f"FILES:\n")
            f.write(f"  - validation_embeddings.pkl: Complete data with all embedding variants\n")
            f.write(f"  - unified_embeddings.npy: Main unified embeddings (concatenation)\n")
            f.write(f"  - unified_embeddings_*.npy: Different fusion methods\n")
            f.write(f"  - text_embeddings.npy: Original text embeddings\n")
            f.write(f"  - image_embeddings.npy: Original image embeddings\n")
            f.write(f"  - sample_ids.npy: Sample IDs\n")
            f.write(f"  - captions_with_ids.txt: Captions with corresponding IDs\n")
        
        print(f"   💾 Saved unified embeddings to: {model_embeddings_dir}")
        print(f"      Main files: validation_embeddings.pkl, unified_embeddings.npy")
        print(f"      Fusion variants: unified_embeddings_concat.npy, unified_embeddings_avg.npy,")
        print(f"                      unified_embeddings_multiply.npy, unified_embeddings_max.npy")
        print(f"      Original: text_embeddings.npy, image_embeddings.npy")
        print(f"      Metadata: sample_ids.npy, captions_with_ids.txt, embedding_info.txt")
        
        return unified_embeddings, sample_ids, captions, text_embeddings, image_embeddings
        
    except Exception as e:
        print(f"   ❌ Error generating unified embeddings for {model_name}: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, None, None

print("✅ Unified embedding generation function defined")

In [None]:
# Generate unified embeddings for all loaded models
if val_df is not None and len(loaded_models) > 0:
    print(f"🚀 Starting unified embedding generation for {len(loaded_models)} models...")
    
    embedding_results = {}
    
    # Process each loaded model
    for model_name, model, tokenizer_or_processor in loaded_models:
        # Get the appropriate dataset class
        if model_name == 'CLIPP-SciBERT':
            dataset_class = ImageTextDatasetSciBERT
        elif model_name == 'CLIPP-DistilBERT':
            dataset_class = ImageTextDatasetDistilBERT
        elif model_name == 'MobileCLIP':
            dataset_class = ImageTextDatasetMobileCLIP
        elif model_name == 'BLIP':
            dataset_class = ImageTextDatasetBLIP
        else:
            print(f"⚠️  Unknown model type: {model_name}, skipping...")
            continue
        
        # Generate unified embeddings
        unified_emb, sample_ids, captions, text_emb, img_emb = generate_embeddings_for_model(
            model_name, model, tokenizer_or_processor, dataset_class, val_df, device
        )
        
        if unified_emb is not None:
            embedding_results[model_name] = {
                'unified_embeddings': unified_emb,
                'text_embeddings': text_emb,  # Keep for comparison
                'image_embeddings': img_emb,  # Keep for comparison
                'sample_ids': sample_ids,
                'captions': captions
            }
            
            # Verify pairing for this model
            print(f"   🔍 Pairing verification for {model_name}:")
            print(f"      - Unified embedding shape: {unified_emb.shape}")
            print(f"      - Text/Image embedding shapes: {text_emb.shape}, {img_emb.shape}")
            print(f"      - Caption count matches: {len(captions) == unified_emb.shape[0]}")
            print(f"      - Sample ID count matches: {len(sample_ids) == unified_emb.shape[0]}")
            print(f"      - Sample IDs are unique: {len(set(sample_ids)) == len(sample_ids)}")
            
            # Show fusion information
            unified_dim = unified_emb.shape[1]
            text_dim = text_emb.shape[1]
            img_dim = img_emb.shape[1]
            print(f"      - Fusion: {text_dim}D (text) + {img_dim}D (image) → {unified_dim}D (unified)")
    
    print(f"\n🎉 Completed unified embedding generation!")
    print(f"✅ Successfully generated unified embeddings for {len(embedding_results)} models:")
    for model_name in embedding_results.keys():
        print(f"   - {model_name}")
    
    # Show embedding directory structure
    print(f"\n📁 Embeddings directory structure:")
    for item in embeddings_dir.iterdir():
        if item.is_dir():
            print(f"   📂 {item.name}/")
            for file in sorted(item.iterdir()):
                if file.is_file():
                    file_size = file.stat().st_size / (1024*1024)  # MB
                    print(f"      📄 {file.name} ({file_size:.1f} MB)")
                    
    # Show unified embedding statistics
    print(f"\n📊 Unified Embedding Statistics:")
    for model_name, results in embedding_results.items():
        unified = results['unified_embeddings']
        print(f"   {model_name}:")
        print(f"     Shape: {unified.shape}")
        print(f"     Mean: {unified.mean():.4f}")
        print(f"     Std: {unified.std():.4f}")
        print(f"     Min: {unified.min():.4f}")
        print(f"     Max: {unified.max():.4f}")

else:
    if val_df is None:
        print("❌ Cannot generate embeddings: validation data not loaded")
    if len(loaded_models) == 0:
        print("❌ Cannot generate embeddings: no models loaded")

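### 6.1 Compute Similarity Metrics on Validation Set

The next cell analyzes the unified embeddings of each model: it computes pairwise cosine self-similarity statistics (mean, standard deviation, minimum, maximum, and average top-k neighbor similarity) and compares several ways of fusing the text and image embeddings. It does not compute cross-modal text-to-image or image-to-text recall.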
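
For reference, cross-modal retrieval quality (text-to-image and image-to-text) is commonly summarized with recall@k. The following is a minimal NumPy sketch rather than code from this notebook. It assumes row `i` of the text matrix and row `i` of the image matrix belong to the same sample; `recall_at_k` and the toy arrays are hypothetical names introduced only for illustration.

```python
import numpy as np

def recall_at_k(query_emb, gallery_emb, k_values=(1, 5, 10)):
    """Recall@k for paired embeddings: query i's correct match is gallery i."""
    # L2-normalize rows so the dot product equals cosine similarity
    q = query_emb / np.clip(np.linalg.norm(query_emb, axis=1, keepdims=True), 1e-12, None)
    g = gallery_emb / np.clip(np.linalg.norm(gallery_emb, axis=1, keepdims=True), 1e-12, None)
    sim = q @ g.T  # (N, N); the diagonal holds the matched pairs

    # Rank of the correct item = number of gallery items scoring strictly higher
    correct = np.diag(sim)[:, None]
    ranks = (sim > correct).sum(axis=1)

    return {f"R@{k}": float((ranks < k).mean()) for k in k_values}

# Toy data: 4 pairs, image embeddings are slightly perturbed copies of the text ones
rng = np.random.default_rng(0)
text_emb = rng.normal(size=(4, 8))
image_emb = text_emb + 0.01 * rng.normal(size=(4, 8))

t2i = recall_at_k(text_emb, image_emb)   # text-to-image
i2t = recall_at_k(image_emb, text_emb)   # image-to-text
print(t2i, i2t)
```

With the notebook's data, the same function could be applied to `text_embeddings.numpy()` and `image_embeddings.numpy()` from `embedding_results`, provided both have the same dimensionality.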

In [None]:
def compute_unified_similarity_metrics(unified_embeddings, k_values=(1, 5, 10)):
    """
    Compute similarity metrics for unified embeddings (self-similarity analysis).
    
    Args:
        unified_embeddings: tensor of shape (N, D) - unified embeddings
        k_values: list of k values for top-k analysis
    
    Returns:
        dict: Dictionary with similarity metrics
    """
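    # Normalize embeddings for cosine similarity (torch.norm takes `keepdim`, not `keepdims`;
    # the clamp guards against division by zero for all-zero rows)
    norms = torch.norm(unified_embeddings, dim=1, keepdim=True).clamp(min=1e-12)
    normalized_emb = unified_embeddings / norms
    
    # Compute similarity matrix
    similarity_matrix = normalized_emb @ normalized_emb.T  # (N, N)
    
    # Mask the diagonal (self-similarity = 1.0) on the same device as the similarity matrix
    mask = torch.eye(similarity_matrix.size(0), dtype=torch.bool, device=similarity_matrix.device)
    similarity_matrix_no_diag = similarity_matrix.masked_fill(mask, float('-inf'))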
    
    results = {}
    
    # Compute statistics
    results['similarity_stats'] = {
        'mean': similarity_matrix_no_diag[~mask].mean().item(),
        'std': similarity_matrix_no_diag[~mask].std().item(),
        'min': similarity_matrix_no_diag[~mask].min().item(),
        'max': similarity_matrix_no_diag[~mask].max().item()
    }
    
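    # Find most similar pairs (excluding self)
    results['top_similarities'] = {}
    max_k = similarity_matrix.size(0) - 1  # number of non-self neighbors per sample
    for k in k_values:
        # Clamp k so top-k never reaches the masked (-inf) diagonal entry
        top_k_per_sample = torch.topk(similarity_matrix_no_diag, min(k, max_k), dim=1)
        avg_top_k = top_k_per_sample.values.mean(dim=1).mean().item()
        results['top_similarities'][f'avg_top_{k}'] = avg_top_k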
    
    return results, similarity_matrix

def compare_fusion_methods(text_embeddings, image_embeddings):
    """
    Compare different methods of fusing text and image embeddings.
    
    Args:
        text_embeddings: Text embeddings
        image_embeddings: Image embeddings
    
    Returns:
        dict: Dictionary with different fusion results
    """
    # Ensure same dimensions for fair comparison
    min_dim = min(text_embeddings.shape[1], image_embeddings.shape[1])
    text_proj = text_embeddings[:, :min_dim]
    image_proj = image_embeddings[:, :min_dim]
    
    fusion_methods = {}
    
    # Method 1: Concatenation
    fusion_methods['concatenation'] = torch.cat([text_embeddings, image_embeddings], dim=1)
    
    # Method 2: Element-wise average
    fusion_methods['average'] = (text_proj + image_proj) / 2
    
    # Method 3: Weighted average (text heavy)
    fusion_methods['text_weighted'] = 0.7 * text_proj + 0.3 * image_proj
    
    # Method 4: Weighted average (image heavy)  
    fusion_methods['image_weighted'] = 0.3 * text_proj + 0.7 * image_proj
    
    # Method 5: Element-wise multiplication
    fusion_methods['multiplication'] = text_proj * image_proj
    
    # Method 6: Element-wise maximum
    fusion_methods['maximum'] = torch.max(text_proj, image_proj)
    
    # Method 7: Element-wise minimum
    fusion_methods['minimum'] = torch.min(text_proj, image_proj)
    
    # Compute metrics for each method
    method_metrics = {}
    for method_name, fused_emb in fusion_methods.items():
        metrics, _ = compute_unified_similarity_metrics(fused_emb)
        method_metrics[method_name] = {
            'shape': fused_emb.shape,
            'similarity_stats': metrics['similarity_stats'],
            'top_similarities': metrics['top_similarities']
        }
    
    return method_metrics

# Compute metrics for unified embeddings
if embedding_results:
    print("📊 Computing metrics for unified embeddings:\n")
    
    unified_metrics = {}
    
    for model_name, embeddings in embedding_results.items():
        print(f"🔍 {model_name}:")
        
        unified_emb = embeddings['unified_embeddings']
        text_emb = embeddings['text_embeddings']  
        img_emb = embeddings['image_embeddings']
        
        # Compute unified embedding metrics
        metrics, similarity_matrix = compute_unified_similarity_metrics(unified_emb)
        
        print(f"   📈 Unified Embedding Similarity Stats:")
        stats = metrics['similarity_stats']
        print(f"      Mean similarity: {stats['mean']:.4f}")
        print(f"      Std similarity:  {stats['std']:.4f}")
        print(f"      Min similarity:  {stats['min']:.4f}")
        print(f"      Max similarity:  {stats['max']:.4f}")
        
        print(f"   🏆 Top-K Average Similarities:")
        for k, sim in metrics['top_similarities'].items():
            print(f"      {k.replace('_', '-').title()}: {sim:.4f}")
        
        # Compare fusion methods
        print(f"   🔗 Comparing fusion methods:")
        fusion_comparison = compare_fusion_methods(text_emb, img_emb)
        
        for method, method_metrics in fusion_comparison.items():
            shape = method_metrics['shape']
            mean_sim = method_metrics['similarity_stats']['mean']
            print(f"      {method:<15}: {shape} → mean_sim: {mean_sim:.4f}")
        
        unified_metrics[model_name] = {
            'unified_metrics': metrics,
            'fusion_comparison': fusion_comparison
        }
        print()
    
    # Create unified embedding comparison table
    print("📋 UNIFIED EMBEDDING COMPARISON TABLE:")
    print("=" * 100)
    print(f"{'Model':<20} {'Dimension':<12} {'Mean Sim':<10} {'Std Sim':<10} {'Max Sim':<10} {'Top-1 Avg':<10} {'Top-5 Avg':<10}")
    print("=" * 100)
    
    for model_name, metrics in unified_metrics.items():
        unified_shape = embedding_results[model_name]['unified_embeddings'].shape[1]
        stats = metrics['unified_metrics']['similarity_stats']
        top_sims = metrics['unified_metrics']['top_similarities']
        
        print(f"{model_name:<20} {unified_shape:<12} {stats['mean']:<10.4f} {stats['std']:<10.4f} "
              f"{stats['max']:<10.4f} {top_sims['avg_top_1']:<10.4f} {top_sims['avg_top_5']:<10.4f}")
    
    print("=" * 100)
    
    # Save unified metrics to file
    metrics_file = embeddings_dir / 'unified_embedding_metrics.pkl'
    with open(metrics_file, 'wb') as f:
        pickle.dump(unified_metrics, f)
    print(f"💾 Saved unified embedding metrics to: {metrics_file}")
    
    # Also save fusion comparison
    fusion_file = embeddings_dir / 'fusion_method_comparison.pkl'
    fusion_comparison_all = {model: metrics['fusion_comparison'] for model, metrics in unified_metrics.items()}
    with open(fusion_file, 'wb') as f:
        pickle.dump(fusion_comparison_all, f)
    print(f"💾 Saved fusion method comparison to: {fusion_file}")

else:
    print("❌ No unified embeddings available for computing metrics")

### 6.2 Load Embeddings from Saved Files

You can also load the embeddings later from the saved files for analysis or further processing.
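
As a lighter-weight alternative to the pickle file, the individual `.npy` files can be read directly. The sketch below is illustrative only: it assumes the file names written by the save step above (`unified_embeddings.npy`, `text_embeddings.npy`, `image_embeddings.npy`, `sample_ids.npy`), and `load_embedding_arrays` is a hypothetical helper, not part of the project. It uses a temporary directory as a stand-in for a real model directory.

```python
import tempfile
from pathlib import Path

import numpy as np

def load_embedding_arrays(model_dir):
    """Load the .npy files from one model directory and check row alignment."""
    model_dir = Path(model_dir)
    arrays = {}
    for name in ("unified_embeddings", "text_embeddings", "image_embeddings", "sample_ids"):
        path = model_dir / f"{name}.npy"
        if path.exists():
            arrays[name] = np.load(path)
    # Every array must describe the same samples, in the same order
    lengths = {name: arr.shape[0] for name, arr in arrays.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"Row counts differ: {lengths}")
    return arrays

# Round trip in a temporary directory (stand-in for a saved model directory)
with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    np.save(tmp / "text_embeddings.npy", np.zeros((3, 4), dtype=np.float32))
    np.save(tmp / "image_embeddings.npy", np.ones((3, 4), dtype=np.float32))
    np.save(tmp / "sample_ids.npy", np.array([10, 11, 12]))
    loaded = load_embedding_arrays(tmp)

print({name: arr.shape for name, arr in loaded.items()})
```

Arrays loaded this way are NumPy arrays, so any code that expects `torch` tensors would need an explicit `torch.from_numpy(...)` conversion.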

In [None]:
def load_saved_embeddings(model_name, embeddings_dir='./embeddings', embedding_type='unified'):
    """
    Load previously saved embeddings for a model with focus on unified embeddings.
    
    Args:
        model_name: Name of the model
        embeddings_dir: Directory containing saved embeddings
        embedding_type: Type of embedding to prioritize ('unified', 'text', 'image', 'all').
            Currently unused: every embedding found on disk is returned.
    
    Returns:
        dict: Dictionary with loaded embeddings, IDs, and captions
    """
    embeddings_path = Path(embeddings_dir)
    model_dir = embeddings_path / model_name.lower().replace('-', '_')
    
    if not model_dir.exists():
        print(f"❌ No embeddings found for {model_name} at {model_dir}")
        return None
    
    try:
        # Load pickle file if available (preferred - contains all data)
        pickle_file = model_dir / 'validation_embeddings.pkl'
        if pickle_file.exists():
            with open(pickle_file, 'rb') as f:
                data = pickle.load(f)
            print(f"✅ Loaded embeddings for {model_name} from pickle file")
            
            # Verify unified embedding information
            if 'fusion_info' in data:
                fusion = data['fusion_info']
                print(f"   🔗 Fusion method: {fusion['default_method']}")
                print(f"   📏 Unified dim: {fusion['unified_dim']}, Text: {fusion['original_text_dim']}, Image: {fusion['original_image_dim']}")
            
            # Verify data consistency
            unified_emb = data.get('unified_embeddings')
            text_emb = data.get('text_embeddings')
            img_emb = data.get('image_embeddings')
            captions = data['captions']
            sample_ids = data.get('sample_ids', list(range(len(captions))))
            
            if unified_emb is not None:
                assert len(captions) == len(sample_ids) == unified_emb.shape[0], \
                    f"Data inconsistency: captions={len(captions)}, ids={len(sample_ids)}, unified_emb={unified_emb.shape[0]}"
                print(f"   ✅ Unified embedding pairing verified: {len(sample_ids)} consistent pairs")
            
            return data
        
        # Otherwise load from separate numpy files
        unified_emb_file = model_dir / 'unified_embeddings.npy'
        text_emb_file = model_dir / 'text_embeddings.npy'
        img_emb_file = model_dir / 'image_embeddings.npy'
        sample_ids_file = model_dir / 'sample_ids.npy'
        captions_file = model_dir / 'captions_with_ids.txt'
        
        # Check what's available
        available_files = [f for f in [unified_emb_file, text_emb_file, img_emb_file, sample_ids_file, captions_file] if f.exists()]
        
        if len(available_files) >= 3:  # At least embeddings, ids, and captions
            data = {'model_name': model_name}
            
            # Load unified embeddings if available
            if unified_emb_file.exists():
                data['unified_embeddings'] = np.load(unified_emb_file)
                print(f"✅ Loaded unified embeddings: {data['unified_embeddings'].shape}")
                
                # Load other fusion variants if available
                for variant in ['concat', 'avg', 'multiply', 'max']:
                    variant_file = model_dir / f'unified_embeddings_{variant}.npy'
                    if variant_file.exists():
                        data[f'unified_embeddings_{variant}'] = np.load(variant_file)
            
            # Load original embeddings if available
            if text_emb_file.exists():
                data['text_embeddings'] = np.load(text_emb_file)
            if img_emb_file.exists():
                data['image_embeddings'] = np.load(img_emb_file)
            
            # Load metadata
            if sample_ids_file.exists():
                data['sample_ids'] = np.load(sample_ids_file).tolist()
            
            if captions_file.exists():
                captions = []
                with open(captions_file, 'r') as f:
                    lines = f.readlines()
                    for line in lines:
                        if line.startswith('#') or not line.strip():
                            continue
                        parts = line.strip().split(',', 1)
                        if len(parts) == 2:
                            caption = parts[1].replace('\\,', ',')  # Unescape commas
                            captions.append(caption)
                data['captions'] = captions
            
            # Add metadata
            if 'unified_embeddings' in data:
                data['fusion_info'] = {
                    'default_method': 'concatenation',
                    'unified_dim': data['unified_embeddings'].shape[1]
                }
            
            print(f"✅ Loaded embeddings for {model_name} from separate files")
            print(f"   ✅ Available: {list(data.keys())}")
            return data
        
        else:
            print(f"❌ Insufficient embedding files for {model_name}")
            missing_files = [f.name for f in [unified_emb_file, sample_ids_file, captions_file] if not f.exists()]
            print(f"   Missing critical files: {missing_files}")
            return None
            
    except Exception as e:
        print(f"❌ Error loading embeddings for {model_name}: {e}")
        import traceback
        traceback.print_exc()
        return None

def verify_unified_embedding_pairing(data):
    """
    Verify that unified embeddings are properly constructed and paired.
    
    Args:
        data: Dictionary with embeddings data
    
    Returns:
        bool: True if pairing is correct
    """
    if not data:
        return False
    
    try:
        unified_emb = data.get('unified_embeddings')
        text_emb = data.get('text_embeddings')
        img_emb = data.get('image_embeddings') 
        captions = data.get('captions', [])
        sample_ids = data.get('sample_ids', list(range(len(captions))))
        
        print(f"🔍 Unified embedding verification for {data.get('model_name', 'Unknown')}:")
        
        if unified_emb is not None:
            print(f"   Unified embeddings: {unified_emb.shape}")
        if text_emb is not None:
            print(f"   Text embeddings: {text_emb.shape}")
        if img_emb is not None:
            print(f"   Image embeddings: {img_emb.shape}")
        print(f"   Captions: {len(captions)}")
        print(f"   Sample IDs: {len(sample_ids)}")
        
        # Check unified embedding consistency
        if unified_emb is not None:
            unified_consistent = (unified_emb.shape[0] == len(captions) == len(sample_ids))
            print(f"   ✅ Unified embedding consistency: {unified_consistent}")
            
            # Check if unified dimension makes sense
            if text_emb is not None and img_emb is not None:
                expected_concat_dim = text_emb.shape[1] + img_emb.shape[1]
                is_concatenated = (unified_emb.shape[1] == expected_concat_dim)
                print(f"   🔗 Appears to be concatenated: {is_concatenated} ({unified_emb.shape[1]} vs {expected_concat_dim})")
            
            # Show fusion information if available
            if 'fusion_info' in data:
                fusion = data['fusion_info']
                print(f"   📋 Fusion method: {fusion.get('default_method', 'unknown')}")
                print(f"   📏 Dimensions: {fusion.get('unified_dim', 'unknown')}D unified")
        
        # Check ID uniqueness
        ids_unique = len(set(sample_ids)) == len(sample_ids)
        print(f"   ✅ Sample IDs unique: {ids_unique}")
        
        # Sample some unified embeddings
        if unified_emb is not None and len(captions) > 0:
            print(f"   📋 Sample unified embeddings:")
            for i in [0, len(captions)//2, -1]:
                if i < len(captions):
                    idx = i if i >= 0 else len(captions) + i
                    emb_norm = np.linalg.norm(unified_emb[idx])
                    print(f"      [{idx}] ID:{sample_ids[idx]} norm:{emb_norm:.4f} -> \"{captions[idx][:40]}...\"")
        
        return unified_emb is not None and unified_consistent and ids_unique
        
    except Exception as e:
        print(f"   ❌ Error during verification: {e}")
        return False

# Example: Load unified embeddings from saved files
print("📂 Available embedding directories:")
if embeddings_dir.exists():
    for item in embeddings_dir.iterdir():
        if item.is_dir():
            print(f"   📁 {item.name}")
            
    # Try to load an example (if any exist)
    model_dirs = [d for d in embeddings_dir.iterdir() if d.is_dir()]
    if model_dirs:
        example_model = model_dirs[0].name.replace('_', '-').upper()
        print(f"\n🔍 Example: Loading unified embeddings for {example_model}")
        loaded_data = load_saved_embeddings(example_model, embeddings_dir, embedding_type='unified')
        
        if loaded_data:
            verify_unified_embedding_pairing(loaded_data)
            
            # Show available fusion methods
            if 'unified_embeddings_concat' in loaded_data:
                print(f"\n   🔗 Available fusion variants:")
                for key in loaded_data.keys():
                    if key.startswith('unified_embeddings_') and isinstance(loaded_data[key], np.ndarray):
                        variant = key.replace('unified_embeddings_', '')
                        shape = loaded_data[key].shape
                        print(f"      {variant}: {shape}")
    else:
        print("   (No model directories found)")
else:
    print("   (Embeddings directory does not exist yet)")

### 6.3 Demonstrate Proper Text-Image Pairing

Let's show how to properly match text and image embeddings using the sample IDs.
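
The functions in the next cell pair embeddings by row position, which is valid as long as both matrices were saved in the same order. When two sets of embeddings come from different runs and their row orders may differ, they can be re-aligned by sample ID first. The following is a small sketch under that assumption; `align_by_sample_id` and the toy arrays are hypothetical and introduced only for illustration.

```python
import numpy as np

def align_by_sample_id(ids_a, emb_a, ids_b, emb_b):
    """Reorder emb_b so that its rows follow the sample-ID order of emb_a."""
    ids_a = np.asarray(ids_a).tolist()
    ids_b = np.asarray(ids_b).tolist()
    if len(set(ids_b)) != len(ids_b):
        raise ValueError("sample IDs in the second set are not unique")
    # Map each sample ID to its row position in the second set
    row_of = {sample_id: row for row, sample_id in enumerate(ids_b)}
    missing = [s for s in ids_a if s not in row_of]
    if missing:
        raise KeyError(f"IDs missing from the second set: {missing}")
    order = np.array([row_of[s] for s in ids_a], dtype=int)
    return emb_a, emb_b[order]

text_ids = [101, 102, 103]
text_emb = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
image_ids = [103, 101, 102]              # same samples, different row order
image_emb = np.array([[1.0, 1.0], [1.0, 0.0], [0.0, 1.0]])

text_aligned, image_aligned = align_by_sample_id(text_ids, text_emb, image_ids, image_emb)
print(np.array_equal(text_aligned, image_aligned))
```

After alignment, row `i` of both matrices refers to the same sample, so the index-based pairing used below applies unchanged.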

In [None]:
def demonstrate_pairing(embedding_data, num_examples=5):
    """
    Demonstrate how text and image embeddings are properly paired using sample IDs.
    
    Args:
        embedding_data: Dictionary with embeddings and metadata
        num_examples: Number of examples to show
    """
    if not embedding_data:
        print("❌ No embedding data provided")
        return
    
    text_emb = embedding_data['text_embeddings']
    img_emb = embedding_data['image_embeddings']
    captions = embedding_data['captions']
    sample_ids = embedding_data.get('sample_ids', list(range(len(captions))))
    model_name = embedding_data.get('model_name', 'Unknown')
    
    print(f"🔗 Demonstrating pairing for {model_name}:")
    print(f"   Total pairs: {len(sample_ids)}")
    print(f"   Text embedding dim: {text_emb.shape[1]}")
    print(f"   Image embedding dim: {img_emb.shape[1]}")
    print()
    
    # Show some examples
    indices = np.linspace(0, len(sample_ids)-1, num_examples, dtype=int)
    
    for i, idx in enumerate(indices):
        sample_id = sample_ids[idx]
        caption = captions[idx]
        text_embedding = text_emb[idx]
        image_embedding = img_emb[idx]
        
        print(f"📝 Example {i+1} (Index {idx}):")
        print(f"   Sample ID: {sample_id}")
        print(f"   Caption: \"{caption[:80]}{'...' if len(caption) > 80 else ''}\"")
        print(f"   Text embedding: shape={text_embedding.shape}, norm={np.linalg.norm(text_embedding):.4f}")
        print(f"   Image embedding: shape={image_embedding.shape}, norm={np.linalg.norm(image_embedding):.4f}")
        
        # Compute similarity between paired text and image
        similarity = np.dot(text_embedding, image_embedding) / (np.linalg.norm(text_embedding) * np.linalg.norm(image_embedding))
        print(f"   Text-Image similarity: {similarity:.4f}")
        print()

def find_most_similar_pairs(embedding_data, top_k=5):
    """
    Find the most similar text-image pairs (should be the diagonal if properly paired).
    
    Args:
        embedding_data: Dictionary with embeddings and metadata
        top_k: Number of top similar pairs to show
    """
    if not embedding_data:
        print("❌ No embedding data provided")
        return
    
    text_emb = embedding_data['text_embeddings']
    img_emb = embedding_data['image_embeddings']
    captions = embedding_data['captions']
    sample_ids = embedding_data.get('sample_ids', list(range(len(captions))))
    model_name = embedding_data.get('model_name', 'Unknown')
    
    print(f"🔍 Finding most similar text-image pairs for {model_name}:")
    
    # Normalize embeddings for cosine similarity
    text_emb_norm = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    img_emb_norm = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    
    # Compute similarity matrix
    similarity_matrix = text_emb_norm @ img_emb_norm.T
    
    # Get diagonal similarities (correct pairs)
    diagonal_similarities = np.diag(similarity_matrix)
    
    # Find top-k most similar pairs overall
    flat_similarities = similarity_matrix.flatten()
    top_indices = np.argsort(flat_similarities)[-top_k:][::-1]
    
    print(f"   📊 Diagonal similarity stats (correct pairs):")
    print(f"      Mean: {diagonal_similarities.mean():.4f}")
    print(f"      Std:  {diagonal_similarities.std():.4f}")
    print(f"      Min:  {diagonal_similarities.min():.4f}")
    print(f"      Max:  {diagonal_similarities.max():.4f}")
    print()
    
    print(f"   🏆 Top {top_k} most similar pairs overall:")
    for rank, flat_idx in enumerate(top_indices):
        text_idx = flat_idx // similarity_matrix.shape[1]
        img_idx = flat_idx % similarity_matrix.shape[1]
        similarity = similarity_matrix[text_idx, img_idx]
        
        is_correct_pair = (text_idx == img_idx)
        pair_type = "✅ CORRECT" if is_correct_pair else "❌ INCORRECT"
        
        print(f"      {rank+1}. Text[{text_idx}] <-> Image[{img_idx}]: {similarity:.4f} {pair_type}")
        if not is_correct_pair:
            print(f"         Text: \"{captions[text_idx][:60]}...\"")
        
    # Check if top pairs are mostly diagonal (good sign)
    correct_in_top = sum(1 for flat_idx in top_indices if (flat_idx // similarity_matrix.shape[1]) == (flat_idx % similarity_matrix.shape[1]))
    print(f"\n   📈 {correct_in_top}/{top_k} top pairs are correctly paired")

# Test pairing demonstration if we have embedding results
if 'embedding_results' in locals() and embedding_results:
    print("🧪 Testing pairing demonstration with generated embeddings:\n")
    
    # Pick the first available model
    model_name = list(embedding_results.keys())[0]
    demo_data = {
        'text_embeddings': embedding_results[model_name]['text_embeddings'].numpy(),
        'image_embeddings': embedding_results[model_name]['image_embeddings'].numpy(),
        'captions': embedding_results[model_name]['captions'],
        'sample_ids': embedding_results[model_name]['sample_ids'],
        'model_name': model_name
    }
    
    demonstrate_pairing(demo_data, num_examples=3)
    find_most_similar_pairs(demo_data, top_k=10)
    
else:
    print("ℹ️  No embedding results available for pairing demonstration")
    print("   Run the embedding generation cells first to see pairing examples")

In [None]:
# 🚀 Execute the complete unified embedding generation pipeline
print("=" * 80)
print("🚀 STARTING UNIFIED EMBEDDING GENERATION PIPELINE")
print("=" * 80)

try:
    # Generate unified embeddings for all models
    print(f"\n📊 Processing {len(val_df)} validation samples...")
    print(f"📁 Saving to: {embeddings_dir}")
    
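    # Map model names to dataset classes (same mapping as the generation cell above)
    dataset_classes = {
        'CLIPP-SciBERT': ImageTextDatasetSciBERT,
        'CLIPP-DistilBERT': ImageTextDatasetDistilBERT,
        'MobileCLIP': ImageTextDatasetMobileCLIP,
        'BLIP': ImageTextDatasetBLIP,
    }
    
    # Track overall statistics
    total_models = len(loaded_models)
    successful_models = 0
    failed_models = []
    fusion_stats = {}
    
    for i, (model_name, model, tokenizer_or_processor) in enumerate(loaded_models, 1):
        print(f"\n{'='*60}")
        print(f"🔄 [{i}/{total_models}] Processing {model_name}")
        print(f"{'='*60}")
        
        try:
            # generate_embeddings_for_model returns a 5-tuple (all None on failure);
            # an unknown model name raises KeyError and is recorded as a failure below
            unified_emb, sample_ids, captions, text_emb, img_emb = generate_embeddings_for_model(
                model_name, model, tokenizer_or_processor,
                dataset_classes[model_name], val_df, device
            )
            
            if unified_emb is not None:
                successful_models += 1
                
                # Repackage the tuple as a dict for verification
                embedding_data = {
                    'model_name': model_name,
                    'unified_embeddings': unified_emb.numpy(),
                    'text_embeddings': text_emb.numpy(),
                    'image_embeddings': img_emb.numpy(),
                    'sample_ids': sample_ids,
                    'captions': captions,
                    'fusion_info': {
                        'default_method': 'concatenation',
                        'unified_dim': unified_emb.shape[1]
                    }
                }
                
                # Collect fusion statistics
                model_key = model_name.lower().replace('-', '_')
                fusion_stats[model_key] = {
                    'method': 'concatenation',
                    'unified_dim': unified_emb.shape[1],
                    'text_dim': text_emb.shape[1],
                    'image_dim': img_emb.shape[1],
                    'compression_ratio': 'unknown'  # not recorded by the generation function
                }
                
                print(f"✅ Successfully generated unified embeddings for {model_name}")
                
                # Quick verification
                verify_unified_embedding_pairing(embedding_data)
                
            else:
                failed_models.append(model_name)
                print(f"❌ Failed to generate unified embeddings for {model_name}")
                
        except Exception as e:
            failed_models.append(model_name)
            print(f"❌ Exception processing {model_name}: {e}")
            import traceback
            traceback.print_exc()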
    
    # Final summary
    print("\n" + "=" * 80)
    print("🎯 UNIFIED EMBEDDING GENERATION SUMMARY")
    print("=" * 80)
    print(f"✅ Successful models: {successful_models}/{total_models}")
    print(f"❌ Failed models: {len(failed_models)}")
    
    if failed_models:
        print(f"   Failed: {', '.join(failed_models)}")
    
    if fusion_stats:
        print(f"\n📊 FUSION STATISTICS:")
        for model, stats in fusion_stats.items():
            print(f"   🔗 {model.upper()}:")
            print(f"      Method: {stats['method']}")
            print(f"      Dimensions: {stats['text_dim']}D text + {stats['image_dim']}D image → {stats['unified_dim']}D unified")
            if stats['compression_ratio'] != 'unknown':
                print(f"      Compression: {stats['compression_ratio']:.2f}x")
    
    # Show embedding directory structure
    print("\n📁 SAVED EMBEDDING STRUCTURE:")
    if embeddings_dir.exists():
        for model_dir in sorted(embeddings_dir.iterdir()):
            if model_dir.is_dir():
                print(f"   📂 {model_dir.name}/")
                for file in sorted(model_dir.iterdir()):
                    if file.is_file():
                        size_kb = file.stat().st_size / 1024
                        print(f"      📄 {file.name} ({size_kb:.1f} KB)")

    print("\n🎉 Unified embedding generation completed!")
    print(f"📂 All embeddings saved to: {embeddings_dir}")
    
except Exception as e:
    print(f"\n💥 Pipeline failed with error: {e}")
    import traceback
    traceback.print_exc()

# 📝 Text Embedding Pipeline for Individual Models

This section builds a text embedding pipeline for each individual model. Each pipeline accepts a single string or a list of strings and returns L2-normalized text embeddings, which makes it convenient for testing and exploring the models with custom text.
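
The input and output contract that each pipeline follows can be sketched without loading any checkpoint. In the minimal sketch below, `toy_encode`, `l2_normalize`, and `make_pipeline` are hypothetical names, the hash-based encoder is only a stand-in for a real text encoder, and plain Python lists stand in for NumPy arrays so the example runs with the standard library alone.

```python
import hashlib
import math


def toy_encode(text, dim=8):
    """Hypothetical stand-in encoder: maps text to a deterministic vector."""
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    # Use the first `dim` bytes of the hash, centered around zero.
    return [b / 255.0 - 0.5 for b in digest[:dim]]


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else vec


def make_pipeline(encode):
    """Wrap an encoder so it accepts a string or a list of strings."""
    def pipeline(text_input):
        single = isinstance(text_input, str)
        texts = [text_input] if single else list(text_input)
        embeddings = [l2_normalize(encode(t)) for t in texts]
        # One vector for a bare string, one row per text for a list.
        return embeddings[0] if single else embeddings
    return pipeline


pipeline = make_pipeline(toy_encode)
single = pipeline("Silicon dioxide thin film")
batch = pipeline(["Silicon dioxide thin film", "Titanium alloy"])
print(len(single), len(batch), len(batch[0]))
```

A bare string yields one vector, while a list yields one row per text. Every row has unit L2 norm, so the dot product of two rows from the same pipeline is their cosine similarity.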

In [None]:
import torch.nn.functional as F  # every pipeline below calls F.normalize

def create_text_embedding_pipeline(model_name):
    """
    Create a text embedding pipeline for a specific model.
    
    Args:
        model_name: Name of the model ('CLIPP-SciBERT', 'CLIPP-DistilBERT', 'MobileCLIP', 'BLIP')
    
    Returns:
        function: A pipeline function that takes text and returns embeddings
    """
    print(f"🔧 Creating text embedding pipeline for {model_name}...")
    
    try:
        # Load the model
        if model_name == 'CLIPP-SciBERT':
            model, tokenizer, processor, device = load_clipp_scibert_model()
            
            def pipeline(text_input):
                """Generate text embedding using CLIPP-SciBERT."""
                if isinstance(text_input, list):
                    texts = text_input
                else:
                    texts = [text_input]
                
                model.eval()
                embeddings = []
                
                with torch.no_grad():
                    for text in texts:
                        # Tokenize text
                        text_tokens = tokenizer(text, padding=True, truncation=True, 
                                              return_tensors="pt", max_length=512).to(device)
                        
                        # Get text embedding
                        text_features = model.encode_text(text_tokens['input_ids'], text_tokens['attention_mask'])
                        text_features = F.normalize(text_features, p=2, dim=1)
                        
                        embeddings.append(text_features.cpu().numpy())
                
                # Stack into (n_texts, dim); return a 1-D vector only for a bare string
                result = np.vstack(embeddings)
                return result[0] if isinstance(text_input, str) else result
                
        elif model_name == 'CLIPP-DistilBERT':
            model, tokenizer, processor, device = load_clipp_distilbert_model()
            
            def pipeline(text_input):
                """Generate text embedding using CLIPP-DistilBERT."""
                if isinstance(text_input, list):
                    texts = text_input
                else:
                    texts = [text_input]
                
                model.eval()
                embeddings = []
                
                with torch.no_grad():
                    for text in texts:
                        # Tokenize text
                        text_tokens = tokenizer(text, padding=True, truncation=True, 
                                              return_tensors="pt", max_length=512).to(device)
                        
                        # Get text embedding
                        text_features = model.encode_text(text_tokens['input_ids'], text_tokens['attention_mask'])
                        text_features = F.normalize(text_features, p=2, dim=1)
                        
                        embeddings.append(text_features.cpu().numpy())
                
                # Stack into (n_texts, dim); return a 1-D vector only for a bare string
                result = np.vstack(embeddings)
                return result[0] if isinstance(text_input, str) else result
                
        elif model_name == 'MobileCLIP':
            model, tokenizer, processor, device = load_apple_mobileclip_model()
            
            def pipeline(text_input):
                """Generate text embedding using MobileCLIP."""
                if isinstance(text_input, list):
                    texts = text_input
                else:
                    texts = [text_input]
                
                model.eval()
                embeddings = []
                
                with torch.no_grad():
                    for text in texts:
                        # Tokenize text using open_clip tokenizer
                        text_tokens = open_clip.tokenize([text]).to(device)
                        
                        # Get text embedding
                        text_features = model.encode_text(text_tokens)
                        text_features = F.normalize(text_features, p=2, dim=1)
                        
                        embeddings.append(text_features.cpu().numpy())
                
                # Stack into (n_texts, dim); return a 1-D vector only for a bare string
                result = np.vstack(embeddings)
                return result[0] if isinstance(text_input, str) else result
                
        elif model_name == 'BLIP':
            model, processor, device = load_blip_model()
            
            def pipeline(text_input):
                """Generate text embedding using BLIP."""
                if isinstance(text_input, list):
                    texts = text_input
                else:
                    texts = [text_input]
                
                model.eval()
                embeddings = []
                
                with torch.no_grad():
                    for text in texts:
                        # Process text
                        inputs = processor(text=[text], return_tensors="pt", 
                                         padding=True, truncation=True, max_length=512).to(device)
                        
                        # Get text embedding
                        text_embeds = model.get_text_features(**inputs)
                        text_embeds = F.normalize(text_embeds, p=2, dim=1)
                        
                        embeddings.append(text_embeds.cpu().numpy())
                
                # Stack into (n_texts, dim); return a 1-D vector only for a bare string
                result = np.vstack(embeddings)
                return result[0] if isinstance(text_input, str) else result
                
        else:
            raise ValueError(f"Unknown model: {model_name}")
            
        print(f"✅ Text embedding pipeline created for {model_name}")
        return pipeline

    except Exception as e:
        print(f"❌ Failed to create pipeline for {model_name}: {e}")
        import traceback
        traceback.print_exc()
        return None

def create_all_text_pipelines():
    """
    Create text embedding pipelines for all available models.
    
    Returns:
        dict: Dictionary mapping model names to their pipeline functions
    """
    print("🏭 Creating text embedding pipelines for all models...")
    pipelines = {}
    
    for model_name in MODELS_TO_TEST:
        print(f"\n{'='*50}")
        pipeline = create_text_embedding_pipeline(model_name)
        if pipeline:
            pipelines[model_name] = pipeline
            print(f"✅ Pipeline ready for {model_name}")
        else:
            print(f"❌ Failed to create pipeline for {model_name}")

    print(f"\n🎉 Created {len(pipelines)}/{len(MODELS_TO_TEST)} text embedding pipelines")
    return pipelines

In [None]:
def test_text_embeddings_with_examples():
    """
    Test text embedding pipelines with example materials science texts.
    """
    print("🧪 Testing text embedding pipelines with example materials...")

    # Example materials science texts
    example_texts = [
        "Silicon dioxide thin film with high dielectric constant",
        "Graphene-based composite material for energy storage applications",
        "Perovskite solar cell with enhanced efficiency and stability",
        "Titanium alloy with superior mechanical properties",
        "Carbon nanotube reinforced polymer matrix composite",
        "Aluminum oxide nanoparticles for catalytic applications",
        "Copper-zinc alloy with antimicrobial properties",
        "Lithium-ion battery cathode material with high capacity"
    ]
    
    # Create all pipelines
    pipelines = create_all_text_pipelines()
    
    if not pipelines:
        print("❌ No pipelines available for testing")
        return

    print(f"\n🔬 Testing with {len(example_texts)} example texts...")
    print(f"📋 Available pipelines: {list(pipelines.keys())}")
    
    # Test each pipeline
    results = {}
    for model_name, pipeline in pipelines.items():
        print(f"\n{'='*60}")
        print(f"🧪 Testing {model_name} pipeline")
        print(f"{'='*60}")
        
        try:
            # Test single text
            single_text = example_texts[0]
            print(f"📝 Single text: \"{single_text[:50]}...\"")

            single_embedding = pipeline(single_text)
            print(f"✅ Single embedding shape: {single_embedding.shape}")
            print(f"   Embedding norm: {np.linalg.norm(single_embedding):.4f}")
            print(f"   Sample values: [{single_embedding[0]:.4f}, {single_embedding[1]:.4f}, ..., {single_embedding[-1]:.4f}]")
            
            # Test batch processing
            batch_texts = example_texts[:3]
            print(f"\n📝 Batch processing {len(batch_texts)} texts...")

            batch_embeddings = pipeline(batch_texts)
            print(f"✅ Batch embeddings shape: {batch_embeddings.shape}")
            print(f"   Individual norms: {[f'{np.linalg.norm(emb):.4f}' for emb in batch_embeddings]}")
            
            # Compute similarities within batch
            if len(batch_embeddings.shape) > 1 and batch_embeddings.shape[0] > 1:
                similarities = np.dot(batch_embeddings, batch_embeddings.T)
                print(f"   Similarity matrix diagonal: {np.diag(similarities)}")
                print(f"   Off-diagonal similarities: {similarities[0,1]:.4f}, {similarities[0,2]:.4f}, {similarities[1,2]:.4f}")
            
            # Store results
            results[model_name] = {
                'single_embedding': single_embedding,
                'batch_embeddings': batch_embeddings,
                'embedding_dim': single_embedding.shape[-1],
                'success': True
            }
            
            print(f"✅ {model_name} pipeline test successful")

        except Exception as e:
            print(f"❌ {model_name} pipeline test failed: {e}")
            results[model_name] = {'success': False, 'error': str(e)}
            import traceback
            traceback.print_exc()
    
    # Summary
    print(f"\n{'='*80}")
    print("🎯 TEXT EMBEDDING PIPELINE TEST SUMMARY")
    print(f"{'='*80}")

    successful_models = [name for name, result in results.items() if result.get('success', False)]
    failed_models = [name for name, result in results.items() if not result.get('success', False)]

    print(f"✅ Successful models: {len(successful_models)}/{len(results)}")
    print(f"❌ Failed models: {len(failed_models)}")

    if successful_models:
        print("\n📊 EMBEDDING DIMENSIONS:")
        for model_name in successful_models:
            dim = results[model_name]['embedding_dim']
            print(f"   🔗 {model_name}: {dim}D")

    if failed_models:
        print(f"\n❌ Failed models: {', '.join(failed_models)}")
        for model_name in failed_models:
            error = results[model_name].get('error', 'Unknown error')
            print(f"   {model_name}: {error}")

    print("\n🎉 Text embedding pipeline testing completed!")
    return results

def interactive_text_embedding_demo(pipelines=None):
    """
    Interactive demo for testing text embeddings with custom input.
    
    Args:
        pipelines: Dict of model pipelines (will create if None)
    """
    if pipelines is None:
        print("🔧 Creating pipelines for interactive demo...")
        pipelines = create_all_text_pipelines()

    if not pipelines:
        print("❌ No pipelines available for demo")
        return

    print("\n🎮 INTERACTIVE TEXT EMBEDDING DEMO")
    print(f"{'='*50}")
    print(f"Available models: {', '.join(pipelines.keys())}")
    print(f"{'='*50}")
    
    # Demo texts (can be customized)
    demo_texts = [
        "High-performance lithium-ion battery material",
        "Transparent conducting oxide thin film",
        "Magnetic nanoparticles for biomedical applications",
        "Flexible organic photovoltaic device"
    ]
    
    print("📝 Demo texts:")
    for i, text in enumerate(demo_texts, 1):
        print(f"   {i}. {text}")

    # Generate embeddings for all models
    print("\n🔄 Generating embeddings for all models...")

    all_embeddings = {}
    for model_name, pipeline in pipelines.items():
        print(f"\n🔧 Processing with {model_name}...")
        try:
            embeddings = pipeline(demo_texts)
            all_embeddings[model_name] = embeddings
            print(f"✅ Generated embeddings: {embeddings.shape}")
        except Exception as e:
            print(f"❌ Failed: {e}")
    
    # Compare per-text embedding shapes and norms across models.
    # Vectors from different models live in different spaces, so no cross-model cosine is computed here.
    if len(all_embeddings) > 1:
        print("\n📊 CROSS-MODEL EMBEDDING COMPARISON")
        print(f"{'='*50}")

        for text_idx in (0, 1):  # Analyze the first two texts
            print(f"\n📝 Text {text_idx+1}: \"{demo_texts[text_idx]}\"")

            # Get embeddings for this text from all models
            text_embeddings = {}
            for model_name, emb in all_embeddings.items():
                text_embeddings[model_name] = emb[text_idx] if emb.ndim > 1 else emb

            # Report dimension and norm per model
            if len(text_embeddings) > 1:
                print("   📏 Embedding dimensions:")
                for model_name, emb in text_embeddings.items():
                    print(f"      {model_name}: {emb.shape} (norm: {np.linalg.norm(emb):.4f})")
    
    return all_embeddings

In [None]:
def save_text_embedding_pipeline_results(results, save_dir='./text_embeddings'):
    """
    Save text embedding pipeline results to files.
    
    Args:
        results: Results from test_text_embeddings_with_examples()
        save_dir: Directory to save results
    """
    import json
    from datetime import datetime

    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    print(f"💾 Saving text embedding results to {save_path}")

    successful_results = {name: data for name, data in results.items() if data.get('success', False)}

    if not successful_results:
        print("❌ No successful results to save")
        return

    # Save individual model results
    for model_name, data in successful_results.items():
        model_dir = save_path / model_name.lower().replace('-', '_')
        model_dir.mkdir(exist_ok=True)

        # Save embeddings
        if 'single_embedding' in data:
            np.save(model_dir / 'single_text_embedding.npy', data['single_embedding'])
        if 'batch_embeddings' in data:
            np.save(model_dir / 'batch_text_embeddings.npy', data['batch_embeddings'])

        # Save metadata (shapes are stored as lists so they serialize cleanly to JSON)
        metadata = {
            'model_name': model_name,
            'embedding_dim': int(data['embedding_dim']),
            'single_shape': list(data['single_embedding'].shape) if 'single_embedding' in data else None,
            'batch_shape': list(data['batch_embeddings'].shape) if 'batch_embeddings' in data else None
        }

        with open(model_dir / 'text_embedding_metadata.json', 'w') as f:
            json.dump(metadata, f, indent=2)

        print(f"✅ Saved {model_name} results to {model_dir}")

    # Save summary
    summary = {
        'successful_models': list(successful_results.keys()),
        'failed_models': [name for name, data in results.items() if not data.get('success', False)],
        'embedding_dimensions': {name: int(data['embedding_dim']) for name, data in successful_results.items()},
        'timestamp': datetime.now().isoformat()
    }

    with open(save_path / 'text_embedding_summary.json', 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"💾 Saved summary to {save_path / 'text_embedding_summary.json'}")
    return save_path

def custom_text_embedding_generator(text_input, models_to_use=None):
    """
    Generate embeddings for custom text input using specified models.
    
    Args:
        text_input: Single string or list of strings
        models_to_use: List of model names to use (None for all available)
    
    Returns:
        dict: Model name -> embeddings mapping
    """
    print("🔤 Generating embeddings for custom text input...")

    if isinstance(text_input, str):
        print(f"📝 Input text: \"{text_input[:100]}{'...' if len(text_input) > 100 else ''}\"")
    else:
        print(f"📝 Input: {len(text_input)} texts")
        for i, text in enumerate(text_input[:3]):  # Show first 3
            print(f"   {i+1}. \"{text[:80]}{'...' if len(text) > 80 else ''}\"")
        if len(text_input) > 3:
            print(f"   ... and {len(text_input)-3} more")
    
    # Create pipelines, loading only the requested models when a subset is given
    if models_to_use:
        pipelines = {}
        for name in models_to_use:
            pipeline = create_text_embedding_pipeline(name)
            if pipeline:
                pipelines[name] = pipeline
        missing = set(models_to_use) - set(pipelines)
        if missing:
            print(f"⚠️  Requested models not available: {missing}")
    else:
        pipelines = create_all_text_pipelines()

    if not pipelines:
        print("❌ No pipelines available")
        return {}

    print(f"🔧 Using models: {list(pipelines.keys())}")
    
    # Generate embeddings
    results = {}
    for model_name, pipeline in pipelines.items():
        print(f"\n🔄 Generating embeddings with {model_name}...")
        try:
            embeddings = pipeline(text_input)
            results[model_name] = embeddings

            if isinstance(text_input, str):
                print(f"✅ Generated embedding: {embeddings.shape} (norm: {np.linalg.norm(embeddings):.4f})")
            else:
                print(f"✅ Generated embeddings: {embeddings.shape}")
                # Row-wise norms; atleast_2d keeps this correct for a one-element list
                norms = np.linalg.norm(np.atleast_2d(embeddings), axis=1)
                print(f"   Norms: min={norms.min():.4f}, max={norms.max():.4f}")

        except Exception as e:
            print(f"❌ Failed with {model_name}: {e}")

    print(f"\n🎉 Generated embeddings with {len(results)}/{len(pipelines)} models")
    return results

## 🚀 Execute Text Embedding Pipeline

Run the text embedding pipeline to test all models with example materials science texts.
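
Two properties are worth checking in the test output: every embedding should have unit L2 norm, and the within-batch similarity matrix should be symmetric with ones on the diagonal. The sketch below shows those checks on hypothetical 2-D vectors; `cosine_matrix` and `sanity_check` are illustrative names that do not exist elsewhere in the notebook, and plain lists are used so the example has no dependencies.

```python
import math


def cosine_matrix(rows):
    """Pairwise cosine similarity for a list of equal-length vectors."""
    norms = [math.sqrt(sum(x * x for x in r)) for r in rows]
    return [
        [sum(a * b for a, b in zip(r1, r2)) / (n1 * n2)
         for r2, n2 in zip(rows, norms)]
        for r1, n1 in zip(rows, norms)
    ]


def sanity_check(rows, tol=1e-6):
    """Return (unit_norm, unit_diagonal, symmetric) flags for a batch."""
    sims = cosine_matrix(rows)
    n = len(rows)
    unit_norm = all(
        abs(math.sqrt(sum(x * x for x in r)) - 1.0) < tol for r in rows
    )
    unit_diag = all(abs(sims[i][i] - 1.0) < tol for i in range(n))
    symmetric = all(
        abs(sims[i][j] - sims[j][i]) < tol
        for i in range(n) for j in range(n)
    )
    return unit_norm, unit_diag, symmetric


# Hypothetical embeddings that are already L2-normalized.
batch = [[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]]
print(sanity_check(batch))
```

If the first flag comes back `False` for a real model, the normalization step in that pipeline was probably skipped, and the dot products printed by the test are not cosine similarities.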

In [None]:
# 🧪 Test text embedding pipelines with predefined examples
print("=" * 80)
print("🧪 TESTING TEXT EMBEDDING PIPELINES")
print("=" * 80)

# Run comprehensive test
test_results = test_text_embeddings_with_examples()

# Save results (the save helper returns None when no model succeeded)
if test_results:
    saved_path = save_text_embedding_pipeline_results(test_results)
    if saved_path:
        print(f"\n💾 Results saved to: {saved_path}")

print("\n🎉 Text embedding pipeline testing completed!")

## 🎮 Custom Text Embedding Examples

Use these examples to generate embeddings for your own custom texts.
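
One common use of custom embeddings is ranking candidate texts against a query within a single model's embedding space. The sketch below assumes hypothetical 3-D vectors in place of real model output; `cosine` and `rank_by_similarity` are illustrative helpers, not functions defined in this notebook.

```python
import math


def cosine(a, b):
    """Cosine similarity between two vectors from the same model."""
    if len(a) != len(b):
        raise ValueError("embeddings must have equal dimensions (same model)")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_by_similarity(query_emb, candidates):
    """Sort a {text: embedding} mapping by similarity to the query, best first."""
    scored = [(text, cosine(query_emb, emb)) for text, emb in candidates.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Hypothetical embeddings, all assumed to come from one model.
query = [1.0, 0.0, 0.0]
candidates = {
    "Graphene oxide membrane": [0.9, 0.1, 0.0],
    "Titanium alloy": [0.0, 1.0, 0.0],
    "Perovskite solar cell": [0.5, 0.5, 0.0],
}

ranking = rank_by_similarity(query, candidates)
for text, score in ranking:
    print(f"{score:.3f}  {text}")
```

The dimension check matters because embeddings from different models generally have different sizes, and even equal-sized vectors from independently trained models are not aligned, so rankings are only meaningful within one model.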

In [None]:
# 🎯 Example 1: Single custom text
custom_text = "Advanced polymer composite with carbon fiber reinforcement for aerospace applications"

print("🔤 Generating embeddings for custom text:")
print(f"📝 Text: \"{custom_text}\"")

# Generate embeddings for all models
custom_results = custom_text_embedding_generator(custom_text)

# Display results
if custom_results:
    print("\n📊 EMBEDDING ANALYSIS:")
    for model_name, embedding in custom_results.items():
        norm = np.linalg.norm(embedding)
        print(f"   🔗 {model_name}: {embedding.shape} (norm: {norm:.4f})")
        print(f"      Sample values: [{embedding[0]:.4f}, {embedding[1]:.4f}, ..., {embedding[-1]:.4f}]")
    
    # Compare across models. Each model has its own embedding space, so a raw
    # cosine between two models is not a meaningful semantic score, and it is
    # undefined when the dimensions differ.
    if len(custom_results) > 1:
        print("\n🔄 CROSS-MODEL SIMILARITIES:")
        model_names = list(custom_results.keys())
        for i in range(len(model_names)):
            for j in range(i+1, len(model_names)):
                model1, model2 = model_names[i], model_names[j]
                emb1, emb2 = custom_results[model1], custom_results[model2]

                if emb1.shape != emb2.shape:
                    print(f"   {model1} ↔ {model2}: skipped (dimensions {emb1.shape[-1]} vs {emb2.shape[-1]})")
                    continue

                # Normalize embeddings before taking the dot product
                emb1_norm = emb1 / np.linalg.norm(emb1)
                emb2_norm = emb2 / np.linalg.norm(emb2)

                similarity = np.dot(emb1_norm, emb2_norm)
                print(f"   {model1} ↔ {model2}: {similarity:.4f}")
else:
    print("❌ No embeddings generated")

In [None]:
# 🎯 Example 2: Batch processing multiple texts
batch_texts = [
    "Silicon carbide semiconductor for high-power electronics",
    "Organic photovoltaic cell with improved efficiency",
    "Magnetic nanoparticles for targeted drug delivery",
    "Flexible conducting polymer for wearable devices",
    "Ceramic matrix composite for high-temperature applications"
]

print("🔤 Generating embeddings for batch of texts:")
print(f"📝 Processing {len(batch_texts)} texts:")
for i, text in enumerate(batch_texts, 1):
    print(f"   {i}. {text}")

# Generate embeddings for all models
batch_results = custom_text_embedding_generator(batch_texts)

# Analyze batch results
if batch_results:
    print("\n📊 BATCH EMBEDDING ANALYSIS:")
    for model_name, embeddings in batch_results.items():
        print(f"\n🔗 {model_name}:")
        print(f"   Shape: {embeddings.shape}")
        
        # Compute statistics
        norms = [np.linalg.norm(emb) for emb in embeddings]
        print(f"   Norms: min={min(norms):.4f}, max={max(norms):.4f}, mean={np.mean(norms):.4f}")
        
        # Compute pairwise similarities within batch
        if embeddings.shape[0] > 1:
            # Normalize embeddings
            normalized_embs = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
            similarity_matrix = np.dot(normalized_embs, normalized_embs.T)
            
            # Show similarity statistics
            off_diagonal = similarity_matrix[np.triu_indices_from(similarity_matrix, k=1)]
            print(f"   Similarities: min={off_diagonal.min():.4f}, max={off_diagonal.max():.4f}, mean={off_diagonal.mean():.4f}")
            
            # Show most similar pair (mask the diagonal so self-similarity is never picked,
            # even when every off-diagonal similarity is negative)
            masked = similarity_matrix.copy()
            np.fill_diagonal(masked, -np.inf)
            max_sim_idx = np.unravel_index(np.argmax(masked), masked.shape)
            max_sim_val = similarity_matrix[max_sim_idx]
            print(f"   Most similar: Text {max_sim_idx[0]+1} ↔ Text {max_sim_idx[1]+1} (similarity: {max_sim_val:.4f})")
            
    # Compare embedding patterns across models
    if len(batch_results) > 1:
        print("\n🔄 CROSS-MODEL COMPARISON:")
        model_names = list(batch_results.keys())

        # Compare embedding dimensions
        dims = {name: embs.shape[1] for name, embs in batch_results.items()}
        print(f"   Embedding dimensions: {dims}")

        # Compare first text across models. Pairs with different dimensions are skipped;
        # independently trained embedding spaces are not aligned, so read these values with caution.
        print("\n   First text across models:")
        first_text_embeddings = {name: embs[0] for name, embs in batch_results.items()}

        for i, model1 in enumerate(model_names):
            for model2 in model_names[i+1:]:
                emb1 = first_text_embeddings[model1]
                emb2 = first_text_embeddings[model2]

                if emb1.shape != emb2.shape:
                    print(f"      {model1} ↔ {model2}: skipped (dimensions {emb1.shape[0]} vs {emb2.shape[0]})")
                    continue

                # Normalize for comparison
                emb1_norm = emb1 / np.linalg.norm(emb1)
                emb2_norm = emb2 / np.linalg.norm(emb2)

                similarity = np.dot(emb1_norm, emb2_norm)
                print(f"      {model1} ↔ {model2}: {similarity:.4f}")
else:
    print("❌ No batch embeddings generated")

In [None]:
# 🎯 Example 3: Interactive custom text input
# You can modify these texts to test with your own materials

def quick_text_embedding_demo():
    """Quick demo function for testing custom texts."""
    print("🎮 QUICK TEXT EMBEDDING DEMO")
    print("=" * 50)
    
    # Define your custom texts here
    your_texts = [
        "PUT YOUR CUSTOM TEXT HERE",
        "Quantum dots for display applications",
        "Biodegradable polymer for medical implants",
        "Superconducting material at room temperature"
    ]
    
    # You can also test with a single text
    single_test_text = "Graphene oxide membrane for water filtration"
    
    print("📝 Testing single text:")
    print(f"   \"{single_test_text}\"")
    
    # Test single text
    single_results = custom_text_embedding_generator(single_test_text, 
                                                   models_to_use=['CLIPP-SciBERT', 'MobileCLIP'])  # Specify models if desired
    
    if single_results:
        print(f"\n✅ Generated embeddings for {len(single_results)} models")
        for model, emb in single_results.items():
            print(f"   {model}: {emb.shape} (norm: {np.linalg.norm(emb):.4f})")
    
    print("\n📝 Testing batch texts:")
    # Filter out placeholder text
    real_texts = [text for text in your_texts if not text.startswith("PUT YOUR")]
    
    if real_texts:
        batch_results = custom_text_embedding_generator(real_texts)
        
        if batch_results:
            print(f"\n✅ Generated batch embeddings for {len(batch_results)} models")
            for model, embs in batch_results.items():
                print(f"   {model}: {embs.shape}")
        
        return single_results, batch_results
    else:
        print("   (Modify 'your_texts' list above to test custom inputs)")
        return single_results, None

# Run the demo
demo_single, demo_batch = quick_text_embedding_demo()

print("\n💡 TIP: To test your own texts:")
print("   1. Modify the 'your_texts' list in this cell")
print("   2. Replace 'single_test_text' with your text")
print("   3. Re-run the cell")
print("   4. Use the 'models_to_use' parameter to test specific models only")