# Subset Selection Notebook 2: Embedding Generation

## Overview
This notebook is the second step in the **Subset Selection Pipeline**. It focuses on generating high-quality embeddings from the formatted text data prepared in Notebook 1.

## Purpose in Subset Selection
Embeddings are vector representations of text that capture semantic meaning. In subset selection, we need embeddings to:
1. **Measure Similarity**: Calculate how similar different data samples are to each other
2. **Enable Facility Location**: The subset selection algorithm uses embeddings to identify diverse, representative samples
3. **Preserve Semantic Information**: Ensure selected subsets maintain the semantic diversity of the original dataset
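
Because the encoder in this notebook L2-normalizes every embedding, the cosine similarity used to compare samples reduces to a plain dot product. The NumPy sketch below illustrates this with made-up toy vectors, not real Arctic embeddings; `l2_normalize` and `cosine_similarity_matrix` are illustrative helpers, not part of the pipeline.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row (one embedding per row) to unit L2 norm."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.clip(norms, 1e-12, None)


def cosine_similarity_matrix(x: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity computed directly from its definition."""
    norms = np.linalg.norm(x, axis=1)
    return (x @ x.T) / np.outer(norms, norms)


# Toy "embeddings": 3 samples, 4 dimensions (made-up values)
emb = np.array([[1.0, 2.0, 0.0, 1.0],
                [0.5, 1.0, 0.0, 0.5],   # same direction as row 0
                [0.0, 0.0, 3.0, 0.0]])  # orthogonal to rows 0 and 1

unit = l2_normalize(emb)
sim = unit @ unit.T  # plain dot product on normalized rows
print(np.round(sim, 3))
```

Rows 0 and 1 point in the same direction, so their similarity is 1 even though their lengths differ; this is why normalized embeddings suit similarity-based selection.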

## Output
- **embeddings.h5**: Merged embedding file containing vector representations of all samples
- Used in Notebook 3 for subset selection

In [None]:
# Import from Notebook 1 using %run magic command
# This executes the entire Notebook 1 in the current namespace
%run "data_preparation_and_config.ipynb"

# Additional imports needed for embedding generation.
# torch, np, os, tqdm, Environment/BaseLoader, DataProcessor and
# retry_on_exception are used below without being imported here; they come
# from Notebook 1 via the %run above.
import logging
from dataclasses import dataclass
from multiprocessing import Pool
from typing import Any, Dict, List, Optional, TypedDict, Union

# Third-party imports (not in Notebook 1)
import h5py
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Set up logger for this notebook
logger = logging.getLogger(__name__)

print("✅ Successfully imported from Notebook 1!")
print(f"   • config: {type(config).__name__ if 'config' in locals() else 'Not defined'}")
print(f"   • dataset: {len(dataset) if 'dataset' in locals() and dataset else 'None'} samples")

print("\n📦 Notebook 2 Components Loading...")
print("=" * 60)


### Arctic Encoder Model Configuration
Defines the configuration structure and settings for the Snowflake Arctic embedding model.
**Definitions:**
1. **`ModelConfig`** (TypedDict): Schema for model-specific settings
2. **`ArcticEncoderConfig`** (Dataclass): Internal configuration for an encoder instance
3. **`MODEL_CONFIGS`** (Dict): Pre-configured settings for supported models

**Arctic Model Settings:**
- **Pooling Method**: CLS token (first token of sequence)
- **Normalization**: L2-normalized embeddings for cosine similarity
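- **Max Length**: 4096 tokens (longer inputs are truncated)
- **Default Instruction**: "Retrieve relevant passages:"
- **Batch Size**: 24 samples per forward pass (the encoder's internal batch size)
  - A starting point for GPUs with roughly 8-24 GB of VRAM; actual memory use depends on the padded sequence length and on precision (fp16 vs fp32)
  - Reduce to 12-16 if you hit out-of-memory (OOM) errors
  - Can be increased to 32-48 on high-memory GPUs (e.g. A100, H100)
  - Separate from `config.basic.batch_size`, which sets how many texts are passed to each `encode()` call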
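
The pooling and normalization settings above amount to taking the hidden state of the first (CLS) token and rescaling it to unit length. A minimal NumPy sketch, assuming a random array stands in for the model's `last_hidden_state`; the toy hidden size of 8 is far smaller than the real model's:

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for outputs.last_hidden_state: (batch, seq_len, hidden_dim), toy sizes
last_hidden_state = rng.normal(size=(2, 5, 8)).astype(np.float32)


def cls_pool_and_normalize(hidden: np.ndarray) -> np.ndarray:
    """CLS pooling (token 0 of each sequence) followed by L2 normalization."""
    cls = hidden[:, 0]  # shape: (batch, hidden_dim)
    norms = np.linalg.norm(cls, axis=1, keepdims=True)
    return cls / np.clip(norms, 1e-12, None)  # same epsilon idea as F.normalize


emb = cls_pool_and_normalize(last_hidden_state)
print(emb.shape)  # (2, 8)
```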

In [None]:
# Model configuration 
class ModelConfig(TypedDict):
    pooling_method: str
    normalize_embeddings: bool
    max_length: int
    default_instruction: str
    batch_size: int

@dataclass
class ArcticEncoderConfig:
    """Encoder configuration for Arctic model."""
    model_name: str
    model_config: ModelConfig
    device: torch.device
    num_gpus: int
    batch_size: int
    use_default_instruction: bool
    use_fp16: bool
    testing_mode: bool = False

# Model configurations 
MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "Snowflake/snowflake-arctic-embed-l-v2.0": {
        "pooling_method": "cls",
        "normalize_embeddings": True,
        "max_length": 4096,
        "default_instruction": "Retrieve relevant passages:",
        "batch_size": 24,
    }
}

### ArcticEmbedEncoder Implementation

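Implements the Arctic embedding encoder with the following capabilities:

**Model Loading:**
- Loads the tokenizer and model from the local InstructLab cache (`~/.cache/instructlab/models/<model_name>`)
- In testing mode, the model can be downloaded directly from Hugging Face
- Optionally converts the model to fp16, moves it to the target device, and sets evaluation mode

**Encoding:**
- Prefixes every input with an instruction (`"<instruction>: <text>"`), using the model's default instruction when none is given
- Runs inputs through the model in batches of `batch_size`, truncating to `max_length`
- Applies CLS pooling and L2 normalization, returning PyTorch tensors or NumPy arrays

**Encoder Registry:**
- `ENCODER_REGISTRY`: Maps encoder types to classes
- `get_encoder_class()`: Factory function to get encoder by type

**Design**: Each encoder instance runs on a single GPU; parallelism comes from running one instance per GPU.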
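
The instruction prefixing done by `_prepare_inputs()` can be sketched in plain Python as follows. `prefix_with_instruction` is a hypothetical helper, not part of the notebook; it strips a trailing colon from the instruction, so the default `"Retrieve relevant passages:"` yields a single `:` separator.

```python
from typing import List, Union

# Default instruction value taken from MODEL_CONFIGS above
DEFAULT_INSTRUCTION = "Retrieve relevant passages:"


def prefix_with_instruction(texts: Union[str, List[str]], instruction: str = "") -> List[str]:
    """Prepend '<instruction>: ' to each text, falling back to the default instruction."""
    if isinstance(texts, str):
        texts = [texts]
    # Drop a trailing colon so the separator below is never doubled
    instruction = (instruction or DEFAULT_INSTRUCTION).rstrip(":")
    if not instruction:
        raise ValueError("An instruction is required.")
    return [f"{instruction}: {t}" for t in texts]


print(prefix_with_instruction("What is HDF5?"))
# ['Retrieve relevant passages: What is HDF5?']
```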

In [None]:
class ArcticEmbedEncoder:
    """
    Arctic embedding encoder for generating high-quality text embeddings.
    """
    
    def __init__(
        self,
        model_name: str = "Snowflake/snowflake-arctic-embed-l-v2.0",
        device: Optional[torch.device] = None,
        use_fp16: bool = False,
        use_default_instruction: bool = True,
        testing_mode: bool = False,
    ) -> None:
        """Initialize the encoder for the given model.

        Selects the device (CUDA if available, else CPU), builds the
        configuration, and loads the model via `_initialize_model()`.
        """
        if model_name not in MODEL_CONFIGS:
            raise ValueError(
                f"Model {model_name} not supported. Supported models: {list(MODEL_CONFIGS.keys())}"
            )

        # Use the provided device or default to CUDA
        self.device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Get device ID for logging (torch.device.index is None for "cpu" or a bare "cuda")
        self.device_id = self.device.index if self.device.index is not None else 0

        # Configuration
        self.cfg = ArcticEncoderConfig(
            model_name=model_name,
            model_config=MODEL_CONFIGS[model_name],
            device=self.device,
            num_gpus=1,  # Only use 1 GPU per encoder instance
            batch_size=MODEL_CONFIGS[model_name]["batch_size"],
            use_default_instruction=use_default_instruction,
            use_fp16=use_fp16,
            testing_mode=testing_mode,
        )

        self._initialize_model()

    def _initialize_model(self) -> None:
        """Load the tokenizer and model, then move the model to the device in eval mode.

        Uses the local InstructLab cache when the model is present there. In
        testing mode, a missing model is downloaded from Hugging Face instead;
        otherwise a missing model is an error.
        """
        home_dir = os.path.expanduser("~")
        model_path = os.path.join(
            home_dir, ".cache", "instructlab", "models", self.cfg.model_name
        )

        if os.path.exists(model_path):
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModel.from_pretrained(
                model_path,
                add_pooling_layer=False,
                trust_remote_code=True,
                local_files_only=True,
            )
        elif self.cfg.testing_mode:
            # Testing mode: allow direct download from Hugging Face
            logger.warning(
                f"Model not found locally at {model_path}. "
                "Testing mode enabled - downloading from HuggingFace..."
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name)
            self.model = AutoModel.from_pretrained(
                self.cfg.model_name,
                add_pooling_layer=False,
                trust_remote_code=True,
            )
        else:
            raise ValueError(
                f"Model not found in available models: {self.cfg.model_name}\n"
                "Please run `ilab model download` and download the necessary model"
            )

        if self.cfg.use_fp16:
            self.model = self.model.half()

        self.model = self.model.to(self.cfg.device)
        logger.info(f"Model loaded on device: {self.cfg.device}")

        # Set model to evaluation mode
        self.model.eval()

    def _prepare_inputs(
        self, texts: Union[str, List[str]], instruction: str = ""
    ) -> List[str]:
        """Adds instruction prefix to texts
            Ensures instruction is always present
            Formats inputs for the model
        """
        if isinstance(texts, str):
            texts = [texts]

        # Ensure we always have an instruction
        if not instruction and not self.cfg.use_default_instruction:
            raise ValueError(
                "An instruction must be provided when use_default_instruction is False. "
                "Either provide an instruction or set use_default_instruction to True."
            )

        if (
            not instruction
            and self.cfg.use_default_instruction
            and self.cfg.model_config["default_instruction"]
        ):
            instruction = str(self.cfg.model_config["default_instruction"])

        if not instruction:  # catch if default_instruction is empty
            raise ValueError(
                "No instruction available. Either provide an instruction or ensure "
                "the model config has a valid default_instruction."
            )

        # rstrip(":") avoids "::" when the instruction already ends with a colon
        texts = [f"{instruction.rstrip(':')}: {text}" for text in texts]
        return texts

    @torch.no_grad()
    def encode(
        self,
        inputs: Union[str, List[str]],
        instruction: str = "",
        return_tensors: bool = True,
        show_progress: bool = True,
    ) -> Union[torch.Tensor, np.ndarray]:
        """Generate embeddings for one or more texts.

        Prefixes each text with the instruction, tokenizes and encodes the
        inputs in batches of `batch_size`, applies CLS pooling and L2
        normalization, and returns a PyTorch tensor (or a NumPy array when
        `return_tensors=False`).
        """
        input_was_string = isinstance(inputs, str)
        inputs = self._prepare_inputs(inputs, instruction)

        embeddings_list = []
        for i in tqdm(
            range(0, len(inputs), self.cfg.batch_size),
            disable=not show_progress or len(inputs) < 256,
            desc=f"Encoding on {self.device}",
        ):
            # Tokenize per batch so padding only extends to the longest text
            # in this batch, not to the longest text in the whole input
            batch = self.tokenizer(
                inputs[i : i + self.cfg.batch_size],
                max_length=self.cfg.model_config["max_length"],
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(self.cfg.device)
            outputs = self.model(**batch)
            # Take the first token embedding (CLS) and L2-normalize it
            embeddings = F.normalize(outputs.last_hidden_state[:, 0], p=2, dim=1)
            embeddings_list.append(embeddings.cpu())

        embeddings = torch.cat(embeddings_list, dim=0)
        if input_was_string:
            embeddings = embeddings[0]

        return embeddings if return_tensors else embeddings.numpy()


# Encoder Registry
ENCODER_REGISTRY = {
    "arctic": ArcticEmbedEncoder,
}

def get_encoder_class(encoder_type: str):
    """Return the encoder class registered under `encoder_type`."""
    if encoder_type not in ENCODER_REGISTRY:
        supported_encoders = list(ENCODER_REGISTRY.keys())
        raise ValueError(
            f"Unsupported encoder type: '{encoder_type}'. "
            f"Supported types are: {supported_encoders}"
        )
    return ENCODER_REGISTRY[encoder_type]
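
Adding an encoder means registering its class under a new key, e.g. `ENCODER_REGISTRY["dummy"] = DummyEncoder` in this notebook. The sketch below uses a standalone registry so it runs on its own; `DummyEncoder` and `get_encoder_class_from` are hypothetical. Since `_process_dataset_shard()` constructs encoders with `model_name`, `device`, and `testing_mode` keyword arguments, any registered class needs to accept those.

```python
from typing import Dict, Optional, Type


class DummyEncoder:
    """Hypothetical encoder, used only to illustrate registration."""

    def __init__(self, model_name: str = "dummy", device: Optional[str] = None,
                 testing_mode: bool = False) -> None:
        self.model_name = model_name
        self.device = device
        self.testing_mode = testing_mode


# Standalone stand-in with the same shape as ENCODER_REGISTRY
registry: Dict[str, Type] = {"dummy": DummyEncoder}


def get_encoder_class_from(registry: Dict[str, Type], encoder_type: str) -> Type:
    """Look up an encoder class by key, listing the valid keys on failure."""
    if encoder_type not in registry:
        raise ValueError(
            f"Unsupported encoder type: '{encoder_type}'. "
            f"Supported types are: {list(registry)}"
        )
    return registry[encoder_type]


encoder_cls = get_encoder_class_from(registry, "dummy")
encoder = encoder_cls(model_name="dummy-model", device="cpu", testing_mode=True)
print(type(encoder).__name__, encoder.model_name)  # DummyEncoder dummy-model
```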



### Pairwise Similarity Computation

Defines the `compute_pairwise_dense()` function for calculating similarities between embeddings.

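**What This Function Does:**
- Computes a pairwise score matrix (cosine, dot product, euclidean-based similarity, or RBF) between two sets of vectors; if only one set is given, it is compared with itself
- Processes the inputs in `batch_size` × `batch_size` tiles so only one tile at a time is computed on the GPU; the full result matrix is accumulated on the CPU
- Applies optional scaling (`"min-max"` or `"additive"`)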
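
The tiling can be sketched with NumPy for the cosine case. `blocked_cosine` is a hypothetical helper mirroring the loop structure of `compute_pairwise_dense()`, and the inputs are random toy data:

```python
import numpy as np


def blocked_cosine(a: np.ndarray, b: np.ndarray, batch_size: int) -> np.ndarray:
    """Cosine similarity of every row of `a` with every row of `b`, tile by tile."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    out = np.zeros((a.shape[0], b.shape[0]), dtype=a.dtype)
    for i in range(0, a.shape[0], batch_size):
        for j in range(0, b.shape[0], batch_size):
            # Only this (batch_size x batch_size) tile is computed at once;
            # slicing past the end simply yields a smaller edge tile
            out[i:i + batch_size, j:j + batch_size] = a[i:i + batch_size] @ b[j:j + batch_size].T
    return out


rng = np.random.default_rng(42)
x = rng.normal(size=(7, 5))
y = rng.normal(size=(4, 5))
tiled = blocked_cosine(x, y, batch_size=3)
print(tiled.shape)  # (7, 4)
```

Tiling bounds the per-step working set, but the complete `n_samples1 × n_samples2` result still has to fit in CPU memory.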

In [None]:
def compute_pairwise_dense(
    tensor1: torch.Tensor,
    tensor2: Optional[torch.Tensor] = None,
    batch_size: int = 10000,
    metric: str = "cosine",
    device: Optional[Union[str, torch.device]] = None,
    scaling: Optional[str] = None,
    kw: float = 0.1,
) -> torch.Tensor:
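    """
    Compute a pairwise metric in batches between two sets of vectors.
    This function is needed for similarity computation in subset selection (Notebook 3).
    - `tensor1`, `tensor2`: Input embedding tensors (if `tensor2` is None, `tensor1` is compared with itself)
    - `batch_size`: Number of rows/columns per processing tile (default: 10K)
    - `metric`: Similarity metric ("cosine", "euclidean", "rbf", "dot")
    - `device`: Device to compute on (defaults to CUDA if available, else CPU)
    - `scaling`: Optional scaling ("min-max", "additive", None)
    - `kw`: Kernel-width factor for the "rbf" metric
    Returns an (n_samples1, n_samples2) tensor on the CPU.
    """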
    assert batch_size > 0, "Batch size must be positive."

    if not device:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if tensor2 is None:
        tensor2 = tensor1

    tensor1, tensor2 = tensor1.to(device), tensor2.to(device)
    n_samples1, n_samples2 = tensor1.size(0), tensor2.size(0)
    results = torch.zeros(n_samples1, n_samples2, device="cpu")

    if metric == "cosine":
        tensor1, tensor2 = (
            F.normalize(tensor1, p=2, dim=1),
            F.normalize(tensor2, p=2, dim=1),
        )

    def calculate_metric(a: torch.Tensor, b: torch.Tensor, metric: str, kw: float) -> torch.Tensor:
        if metric in ["cosine", "dot"]:
            return torch.mm(a, b.T)
        if metric == "euclidean":
            distances = torch.cdist(a, b, p=2)
            similarities = 1 / (1 + distances**2)
            return similarities
        if metric == "rbf":
            # Bandwidth = kw * mean squared distance of this tile (not of the full matrix)
            distance = torch.cdist(a, b)
            squared_distance = distance**2
            avg_dist = torch.mean(squared_distance)
            torch.div(squared_distance, kw * avg_dist, out=squared_distance)
            torch.exp(-squared_distance, out=squared_distance)
            return squared_distance
        raise ValueError(f"Unknown metric: {metric}")

    for i in range(0, n_samples1, batch_size):
        end_i = min(i + batch_size, n_samples1)
        rows = tensor1[i:end_i]

        for j in range(0, n_samples2, batch_size):
            end_j = min(j + batch_size, n_samples2)
            cols = tensor2[j:end_j]
            batch_results = calculate_metric(rows, cols, metric, kw).cpu()
            results[i:end_i, j:end_j] = batch_results

    if scaling == "min-max":
        min_val, max_val = results.min(), results.max()
        if max_val != min_val:
            results = (results - min_val) / (max_val - min_val)
    elif scaling == "additive":
        # Maps cosine/dot scores of unit vectors from [-1, 1] to [0, 1]
        results = (results + 1) / 2

    return results
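
For reference, the metrics computed by `calculate_metric()` for a row $a_i$ of `tensor1` and a row $b_j$ of `tensor2` are written out below. Here $\overline{d^2}$ is the mean squared distance over the current tile, which equals the mean over the full matrix only when `batch_size` covers all rows of both inputs.

$$
S^{\mathrm{cos}}_{ij} = \frac{a_i^\top b_j}{\lVert a_i\rVert_2\,\lVert b_j\rVert_2},\qquad
S^{\mathrm{dot}}_{ij} = a_i^\top b_j,\qquad
S^{\mathrm{euc}}_{ij} = \frac{1}{1 + \lVert a_i - b_j\rVert_2^2},\qquad
S^{\mathrm{rbf}}_{ij} = \exp\!\left(-\frac{\lVert a_i - b_j\rVert_2^2}{k_w\,\overline{d^2}}\right)
$$

The optional scalings are $\tilde S = \dfrac{S - \min S}{\max S - \min S}$ for `"min-max"` (skipped when $\max S = \min S$) and $\tilde S = \dfrac{S + 1}{2}$ for `"additive"`.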

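### Multi-GPU Embedding Generation Functions
Implements parallel embedding generation across multiple GPUs.

**Workflow:**
1. Split the dataset (n samples) into N contiguous shards, one per GPU
2. Each GPU encodes its shard and writes `shard_<gpu_id>/embeddings_shard_<gpu_id>.h5`
3. The shard files are merged, in GPU order, into a single `embeddings.h5`

With a single GPU, all n samples form one shard on GPU 0.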
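
The sharding step in `generate_embeddings()` sizes shards with ceiling division. This standalone sketch (`shard_ranges` is a hypothetical helper mirroring that logic) shows the resulting index ranges:

```python
from typing import List, Tuple


def shard_ranges(total_samples: int, num_gpus: int) -> List[Tuple[int, int]]:
    """Return (start, end) index ranges, one per GPU that receives data."""
    per_gpu = (total_samples + num_gpus - 1) // num_gpus  # ceiling division
    ranges = []
    for gpu_id in range(num_gpus):
        start = gpu_id * per_gpu
        end = min(start + per_gpu, total_samples)
        if start >= total_samples:
            continue  # this GPU gets no data
        ranges.append((start, end))
    return ranges


print(shard_ranges(10, 3))  # [(0, 4), (4, 8), (8, 10)]
print(shard_ranges(5, 4))   # [(0, 2), (2, 4), (4, 5)]
```

With 5 samples on 4 GPUs each shard gets 2 samples, so only 3 GPUs receive work and the empty fourth shard is skipped.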

In [None]:
def _process_dataset_shard(args):
    """
    Process one shard of data on a specific GPU.
    - Creates an encoder instance on the assigned GPU
    - Renders each example with the Jinja template
    - Generates embeddings in batches
    - Saves embeddings to a shard-specific HDF5 file
    - Shows a progress bar for monitoring
    """
    (
        gpu_id,
        dataset_shard,
        output_dir,
        encoder_type,
        encoder_model,
        instruction,
        template_name,
        templates,
        batch_size,
        testing_mode,
    ) = args

    try:
        # Set the GPU for this process
        if torch.cuda.is_available():
            torch.cuda.set_device(gpu_id)
            device = f"cuda:{gpu_id}"
        else:
            device = "cpu"
            
        logger.info(f"GPU {gpu_id} started processing {len(dataset_shard)} samples")

        encoder_cls = get_encoder_class(encoder_type)
        encoder = encoder_cls(
            model_name=encoder_model,
            device=torch.device(device),
            testing_mode=testing_mode,
        )

        # Set up Jinja environment for templating
        env = Environment(loader=BaseLoader())
        templates_dict = {k: env.from_string(v) for k, v in templates.items()}

        # Create shard-specific output directory
        shard_dir = os.path.join(output_dir, f"shard_{gpu_id}")
        os.makedirs(shard_dir, exist_ok=True)

        # Process batches
        all_embeddings = []
        batch_texts = []

        # Create progress bar
        progress_bar = tqdm(
            desc=f"GPU {gpu_id} generating embeddings",
            total=len(dataset_shard),
            unit=" samples",
            position=gpu_id,
            leave=True,
        )

        # Look up the template once for the whole shard
        template = templates_dict.get(template_name)
        if not template:
            raise ValueError(f"Unknown format type: {template_name}")

        # Process each example in the shard
        last_idx = len(dataset_shard) - 1
        for idx, example in enumerate(dataset_shard):
            # Format the text using the template
            text = template.render(**example)
            batch_texts.append(text)

            # Process when the batch is full or at the end of the shard
            # (compare indices, not example contents, to detect the last item)
            if len(batch_texts) == batch_size or idx == last_idx:
                # Generate embeddings for this batch
                with torch.no_grad():
                    batch_embeddings = (
                        encoder.encode(
                            inputs=batch_texts,
                            instruction=instruction,
                            return_tensors=False,  # Return numpy for easier handling
                        )
                    )

                all_embeddings.append(batch_embeddings)
                progress_bar.update(len(batch_texts))
                batch_texts = []

                # Clean up GPU memory
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        progress_bar.close()

        # Concatenate all batches
        if not all_embeddings:
            logger.warning(f"No embeddings generated for shard on GPU {gpu_id}")
            return None

        embeddings = np.concatenate(all_embeddings, axis=0)

        # Save embeddings to file
        shard_file = os.path.join(shard_dir, f"embeddings_shard_{gpu_id}.h5")
        with h5py.File(shard_file, "w") as h5f:
            h5f.create_dataset("embeddings", data=embeddings, dtype="float32")

        logger.info(f"GPU {gpu_id} completed processing. Saved to {shard_file}")
        return shard_file

    except Exception as e:
        logger.error(f"Error processing shard on GPU {gpu_id}: {str(e)}")
        raise


def _merge_shard_files(shard_files, merged_file):
    """
    Combines all shard files into single embeddings file
    - Preserves embedding dimension and data type
    - Removes shard files after merging
    - Creates final `embeddings.h5` file
    """
    logger.info(f"Merging {len(shard_files)} shard files into {merged_file}")

    # Get the shape and type of embeddings from the first shard
    with h5py.File(shard_files[0], "r") as f:
        first_embeddings = f["embeddings"]
        embedding_dim = first_embeddings.shape[1]
        dtype = first_embeddings.dtype

    # Count total samples across all shards
    total_samples = 0
    for shard_file in shard_files:
        with h5py.File(shard_file, "r") as f:
            total_samples += f["embeddings"].shape[0]

    # Create the merged file
    with h5py.File(merged_file, "w") as merged_f:
        merged_dataset = merged_f.create_dataset(
            "embeddings", shape=(total_samples, embedding_dim), dtype=dtype
        )

        # Copy embeddings from each shard
        start_idx = 0
        for shard_file in shard_files:
            with h5py.File(shard_file, "r") as shard_f:
                embeddings = shard_f["embeddings"][:]
                end_idx = start_idx + embeddings.shape[0]
                merged_dataset[start_idx:end_idx] = embeddings
                start_idx = end_idx

            # Remove shard file after merging
            os.remove(shard_file)
            # Remove shard directory if empty
            shard_dir = os.path.dirname(shard_file)
            if not os.listdir(shard_dir):
                os.rmdir(shard_dir)

    logger.info(
        f"Successfully merged embeddings from {len(shard_files)} GPUs with {total_samples} total samples"
    )
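
`_merge_shard_files()` preallocates the output and copies each shard into consecutive row ranges. The same offset logic, sketched in memory with NumPy instead of HDF5 (`merge_in_order` is a hypothetical helper, and the shards are toy arrays):

```python
from typing import List

import numpy as np


def merge_in_order(shards: List[np.ndarray]) -> np.ndarray:
    """Copy shards into one preallocated array using running start/end offsets."""
    total = sum(s.shape[0] for s in shards)
    dim, dtype = shards[0].shape[1], shards[0].dtype
    merged = np.empty((total, dim), dtype=dtype)
    start = 0
    for s in shards:
        end = start + s.shape[0]
        merged[start:end] = s
        start = end
    return merged


# Three toy shards of unequal size (4, 4 and 2 rows), 3 dims each
shards = [np.arange(12, dtype=np.float32).reshape(4, 3),
          np.arange(12, 24, dtype=np.float32).reshape(4, 3),
          np.arange(24, 30, dtype=np.float32).reshape(2, 3)]
merged = merge_in_order(shards)
print(merged.shape)  # (10, 3)
```

Row i of the merged output lines up with row i of the dataset only because shards are contiguous slices merged in GPU order; `Pool.map` returns results in the order of its arguments, which preserves this.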



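### DataProcessor with Embedding Generation
Adds the `generate_embeddings()` method to the DataProcessor class from Notebook 1.

**Processing Modes:**
- **Testing Mode / Single GPU**: Serial processing (notebook-friendly)
- **Production / Multi-GPU**: Parallel processing with `multiprocessing.Pool`, one worker per GPU
  - CUDA cannot be re-initialized in a forked subprocess, so parallel mode may fail under Linux's default `fork` start method if CUDA was already initialized in the notebook process; the `spawn` start method, in turn, cannot import functions defined interactively in a notebook

**Decorated with `@retry_on_exception`:**
- Automatically retries on GPU OOM errors
- Cleans up memory between retries
- Uses retry settings from SystemConfig

**Note:** If `embeddings.h5` already exists in the output directory, it is returned as-is, so delete it after changing the dataset or encoder settings.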

In [None]:
@retry_on_exception
def generate_embeddings(self, dataset, output_dir: str) -> str:
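    """
    Generate embeddings for the dataset and save them to the output directory.

    The dataset is split into one contiguous shard per GPU. Shards are
    processed in parallel with multiprocessing, or serially in testing mode
    or when only one GPU is used, and the shard files are then merged.
    If `embeddings.h5` already exists in `output_dir`, it is returned as-is.

    Args:
        dataset: The dataset to process.
        output_dir (str): The directory where embeddings will be saved.

    Returns:
        str: The path to the merged embeddings file.
    """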
    os.makedirs(output_dir, exist_ok=True)
    merged_path = os.path.join(output_dir, "embeddings.h5")

    # If embeddings already exist, return early
    if os.path.exists(merged_path):
        logger.info(f"Embeddings file already exists in {output_dir}, skipping")
        return merged_path

    # Get number of GPUs to use
    num_gpus = min(
        self.config.system.num_gpus,
        torch.cuda.device_count() if torch.cuda.is_available() else 1
    )
    logger.info(f"Using {num_gpus} GPUs for embedding generation")

    # Create dataset shards - one per GPU
    total_samples = len(dataset)
    per_gpu_samples = (total_samples + num_gpus - 1) // num_gpus  # Ceiling division

    # Prepare arguments for parallel processing
    args_list = []
    for gpu_id in range(num_gpus):
        # Calculate start and end indices for this shard
        start_idx = gpu_id * per_gpu_samples
        end_idx = min(start_idx + per_gpu_samples, total_samples)

        if start_idx >= total_samples:
            continue  # Skip if this GPU has no data to process

        # Create arguments for this GPU
        args_list.append(
            (
                gpu_id,
                dataset.select(range(start_idx, end_idx)),
                output_dir,
                self.config.encoder.encoder_type,
                self.config.encoder.encoder_model,
                self.config.encoder.instruction,
                self.config.template.template_name,
                self.config.template.templates,
                self.config.basic.batch_size,
                self.config.encoder.testing_mode,
            )
        )

    # Process dataset shards
    # Use serial processing in testing mode (notebook-friendly)
    # Use parallel processing in production mode
    if self.config.encoder.testing_mode or num_gpus == 1:
        logger.info("Processing shards serially (testing mode or single GPU)")
        shard_files = []
        for args in args_list:
            result = _process_dataset_shard(args)
            shard_files.append(result)
    else:
        logger.info(f"Processing shards in parallel with {num_gpus} workers")
        with Pool(processes=num_gpus) as pool:
            shard_files = pool.map(_process_dataset_shard, args_list)

    # Filter out None values (failed shards)
    shard_files = [f for f in shard_files if f is not None]

    if not shard_files:
        raise ValueError("No embeddings were generated from any GPU")

    # Merge all shard files
    _merge_shard_files(shard_files, merged_path)

    return merged_path


# Add the method to DataProcessor class
DataProcessor.generate_embeddings = generate_embeddings

print("\n" + "=" * 60)
print("✅ All Notebook 2 Components Loaded Successfully!")
print("=" * 60)
print("📦 Components available:")
print("   • Arctic Encoder classes and registry")
print("   • Pairwise similarity computation function")
print("   • Multi-GPU processing functions")
print("   • generate_embeddings() method added to DataProcessor")
print("=" * 60)
print("\n🎯 Ready to generate embeddings!\n")

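### 🚀 Execute Embedding Generation

Run this cell to process your dataset and generate embeddings.

**What happens:**
1. Checks that `dataset` from Notebook 1 is loaded and creates a `DataProcessor` if needed
2. Splits data across available GPUs
3. Generates embeddings using the Arctic encoder
4. Merges results into a single HDF5 file
5. Reports total time, throughput (samples/sec), and file size

If `embeddings.h5` already exists, generation is skipped, so the reported time only reflects that check.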
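
To sanity-check the reported file size: the embeddings are stored as float32, so the raw payload is samples × dimension × 4 bytes. A small hypothetical helper, assuming a 1024-dimensional embedding for snowflake-arctic-embed-l-v2.0 (check `embeddings.shape[1]` on your file to confirm) and ignoring HDF5 metadata overhead:

```python
def estimate_embeddings_mb(num_samples: int, dim: int = 1024, bytes_per_value: int = 4) -> float:
    """Raw float32 payload in MiB (same 1024**2 divisor the cell uses); dim=1024 is an assumption."""
    return num_samples * dim * bytes_per_value / 1024**2


print(estimate_embeddings_mb(100_000))  # 390.625
```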

In [None]:
# Execute Multi-GPU Embedding Generation
import time

if 'dataset' in locals() and dataset is not None:
    print("🎯 Multi-GPU Embedding Generation")
    print("=" * 50)
    
    # Create DataProcessor instance if it doesn't exist
    if 'processor' not in locals():
        processor = DataProcessor(config)
        print("✅ Created DataProcessor instance")
    
    # Set up output directory
    output_dir = os.path.join(config.basic.output_dir, "embeddings")
    
    print(f"\n📊 Dataset size: {len(dataset):,} samples")
    print(f"💾 Output directory: {output_dir}")
    print(f"🎯 Requested GPUs: {config.system.num_gpus} (capped at the number of visible GPUs)")
    print(f"📦 Batch size: {config.basic.batch_size}")
    
    # Generate embeddings using the extended DataProcessor
    print(f"\n🚀 Starting multi-GPU embedding generation...")
    start_time = time.time()
    
    try:
        embeddings_file = processor.generate_embeddings(dataset, output_dir)
        
        generation_time = time.time() - start_time
        
        print(f"\n✅ Embedding generation completed!")
        print(f"⏱️  Total time: {generation_time / 60:.2f} minutes")
        print(f"🎯 Speed: {len(dataset) / generation_time:.2f} samples/sec")
        print(f"💾 Embeddings saved to: {embeddings_file}")
        
        # Show file size
        file_size = os.path.getsize(embeddings_file) / 1024**2
        print(f"📁 File size: {file_size:.1f} MB")
        
    except Exception as e:
        print(f"❌ Error during embedding generation: {e}")
        embeddings_file = None
        
else:
    print("❌ Cannot generate embeddings without dataset")
    print("Please ensure Notebook 1 has been run and dataset was loaded successfully")

### 🎯 Next Steps

**Option 1: Continue to Notebook 3 (Recommended)**
1. Open `subset_selection.ipynb`
2. Run the notebook sequentially
3. Use the embeddings generated here for subset selection
4. Because `generate_embeddings()` skips work when `embeddings.h5` already exists, the embeddings are not recomputed

**Option 2: Reuse These Embeddings Later**
- The embeddings file is saved and can be loaded anytime
- Path: `{config.basic.output_dir}/embeddings/embeddings.h5`
- Open the file with `h5py.File()` and read the `embeddings` dataset (shape: number of samples × embedding dimension) in any notebook or script
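
A minimal loading sketch with h5py. To stay self-contained it first writes a tiny file with the same layout (a single `embeddings` dataset of shape samples × dimension); with real output, call the hypothetical `load_embeddings` helper on the path above.

```python
import os
import tempfile

import h5py
import numpy as np


def load_embeddings(path: str) -> np.ndarray:
    """Read the full `embeddings` dataset into memory as a NumPy array."""
    with h5py.File(path, "r") as f:
        return f["embeddings"][:]


# Self-contained demo: write a tiny file with the same layout, then read it back
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "embeddings.h5")
    with h5py.File(path, "w") as f:
        f.create_dataset("embeddings", data=np.eye(3, 4, dtype="float32"))
    emb = load_embeddings(path)
    print(emb.shape, emb.dtype)  # (3, 4) float32
```

For large files, slicing the dataset directly, e.g. `f["embeddings"][start:end]`, reads only the requested rows instead of loading everything.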