# Subset Selection Notebook 1: Data Preparation & Configuration

## Overview
This notebook is the first step in the **Subset Selection Pipeline**. It focuses on setting up the environment, configuring all necessary parameters, and preparing the input data for processing.

## Purpose in Subset Selection
Configuration and data preparation are critical foundation steps that enable efficient subset selection. This notebook:
1. Sets up all parameters for data processing, encoding, templates, and system resources
2. Loads and inspects input data to ensure it's properly formatted
3. Establishes the DataProcessor class and utility functions used throughout the pipeline
4. Automatically detects and configures available GPU resources for parallel processing

## Output
- **config**: ProcessingConfig object containing all pipeline parameters
- **dataset**: Loaded and validated HuggingFace dataset
- **DataProcessor**: processing class with template, loading, and subset-size utilities; Notebook 2 adds `generate_embeddings()` to it
- These objects are reused in Notebook 2 (embedding generation) and Notebook 3 (subset selection)

## Introduction and Setup

In [None]:
# Install the necessary libraries
%pip install datasets jinja2 tqdm h5py numpy torch transformers submodlib-py==0.0.3 nbformat

## Imports and Logging

In [None]:
# Standard Imports
from dataclasses import dataclass, field
from typing import Any, Dict, List, TypedDict, TypeVar, Union
import logging
import os
import re
import warnings

# Third Party Imports
from datasets import load_dataset, concatenate_datasets
from jinja2 import BaseLoader, Environment
from tqdm import tqdm
import torch
import numpy as np

# Configure logging and warnings
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", category=UserWarning)

### Core Utility Functions

These functions provide critical infrastructure for the embedding generation and subset selection pipeline:

1. **`get_default_num_gpus(testing_mode)`**
   - **Purpose**: Auto-detects GPUs for distributed embedding generation
   - **Used in**: SystemConfig initialization, parallel encoding across GPUs
   - **Behavior**: Returns the number of available CUDA devices; in testing mode with no GPU, logs a warning and returns 1 so the pipeline can run on CPU
   - **Error Handling**: Raises RuntimeError if no GPU is found outside testing mode

2. **`retry_on_exception(func)`**
   - **Purpose**: Automatic retry with cleanup for transient GPU/computation errors
   - **Used in**: Embedding generation and subset selection methods (Notebooks 2 & 3)
   - **Handles**: GPU OOM errors, runtime errors, value/type/index errors; any other exception propagates immediately
   - **Recovery**: Waits `retry_delay` seconds, then runs garbage collection and empties the CUDA cache before the next attempt; re-raises the last error once `max_retries` attempts are exhausted

3. **`display_gpu_info()`**
   - **Purpose**: Display detailed information about available GPU devices
   - **Shows**: GPU count, device names, compute capability, memory (total/allocated/reserved/free), current device
   - **Useful for**: Debugging, resource planning, monitoring GPU usage
   - **Returns**: Dictionary with GPU information
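`retry_on_exception` reads `self.config.system.max_retries` and `self.config.system.retry_delay`, so it only works on methods of an object that carries a `ProcessingConfig`-shaped `config` attribute. The torch-free sketch below illustrates that contract; it is a simplification (no OOM branch, no GPU cleanup), and `retry_sketch` and `FlakyWorker` are hypothetical names used only for this demo.

```python
import time
from functools import wraps
from types import SimpleNamespace


def retry_sketch(func):
    """Torch-free simplification of retry_on_exception (sketch, not the notebook's code)."""
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        system = self.config.system  # same attribute path the notebook's decorator reads
        last_exception = None
        for attempt in range(system.max_retries):
            try:
                return func(self, *args, **kwargs)
            except (RuntimeError, ValueError, TypeError, IndexError) as e:
                last_exception = e
                print(f"Attempt {attempt + 1} failed: {e}")
                if attempt < system.max_retries - 1:
                    time.sleep(system.retry_delay)
        raise last_exception

    return wrapper


class FlakyWorker:
    """Hypothetical worker that fails a fixed number of times before succeeding."""

    def __init__(self, config, failures: int):
        self.config = config
        self.failures = failures
        self.calls = 0

    @retry_sketch
    def run(self, texts):
        self.calls += 1
        if self.calls <= self.failures:
            raise RuntimeError("simulated transient failure")
        return [len(t) for t in texts]


# retry_delay=0 keeps the demo instant; the notebook's default is 30 seconds
demo_cfg = SimpleNamespace(system=SimpleNamespace(max_retries=3, retry_delay=0))
worker = FlakyWorker(demo_cfg, failures=2)
print(worker.run(["hello", "hi"]), worker.calls)
```

With `max_retries=3`, two failures are absorbed and the third call succeeds; a third failure would be re-raised to the caller.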

In [None]:
from functools import wraps
import gc
import time

def get_default_num_gpus(testing_mode: bool = False) -> int:
    """
    Get the default number of GPUs based on available CUDA devices.
    
    Args:
        testing_mode (bool): If True, allows CPU usage with warnings. For testing only.
    """
    if not torch.cuda.is_available():
        if testing_mode:
            logger.warning(
                "No CUDA devices detected. Running in testing mode with CPU. "
                "Production use requires GPU acceleration."
            )
            return 1
        raise RuntimeError(
            "No CUDA devices detected. This functionality requires at least one GPU."
        )
    return torch.cuda.device_count()


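def retry_on_exception(func):
    """
    Decorator that retries a DataProcessor method on transient errors.

    Reads ``self.config.system.max_retries`` and ``self.config.system.retry_delay``,
    so it can only decorate methods of objects that carry a ProcessingConfig.
    Catches CUDA OOM, RuntimeError, ValueError, TypeError and IndexError;
    any other exception propagates immediately.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        last_exception = None
        # Always make at least one attempt, so max_retries=0 cannot end in `raise None`
        max_attempts = max(1, self.config.system.max_retries)
        for attempt in range(max_attempts):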
            try:
                return func(self, *args, **kwargs)
            except torch.cuda.OutOfMemoryError as e:
                last_exception = e
                logger.error(f"GPU out of memory on attempt {attempt + 1}: {str(e)}")
            except RuntimeError as e:
                last_exception = e
                logger.error(f"PyTorch runtime error on attempt {attempt + 1}: {str(e)}")
            except ValueError as e:
                last_exception = e
                logger.error(f"Value error on attempt {attempt + 1}: {str(e)}")
            except TypeError as e:
                last_exception = e
                logger.error(f"Type error on attempt {attempt + 1}: {str(e)}")
            except IndexError as e:
                last_exception = e
                logger.error(f"Index error on attempt {attempt + 1}: {str(e)}")

            if attempt < self.config.system.max_retries - 1:
                logger.info(f"Retrying in {self.config.system.retry_delay} seconds...")
                time.sleep(self.config.system.retry_delay)
                gc.collect()
                torch.cuda.empty_cache()

        raise last_exception

    return wrapper

def display_gpu_info():
    """
    Display detailed information about available GPU devices.
    
    Returns:
        dict: Dictionary containing GPU information
    """
    gpu_info = {
        'cuda_available': torch.cuda.is_available(),
        'gpu_count': 0,
        'gpus': [],
        'current_device': None
    }
    
    if not torch.cuda.is_available():
        print("\n" + "="*60)
        print("🖥️  GPU Information")
        print("="*60)
        print("❌ CUDA is not available")
        print("💡 Running on CPU")
        print("="*60 + "\n")
        return gpu_info
    
    gpu_info['gpu_count'] = torch.cuda.device_count()
    gpu_info['current_device'] = torch.cuda.current_device()
    
    print("\n" + "="*60)
    print("🖥️  GPU Information")
    print("="*60)
    print(f"✅ CUDA is available")
    print(f"📊 Number of GPUs: {gpu_info['gpu_count']}")
    print(f"🎯 Current GPU device: {gpu_info['current_device']}")
    print("="*60)
    
    for i in range(gpu_info['gpu_count']):
        device_props = torch.cuda.get_device_properties(i)
        total_memory = device_props.total_memory / 1024**3  # Convert to GB
        allocated_memory = torch.cuda.memory_allocated(i) / 1024**3
        reserved_memory = torch.cuda.memory_reserved(i) / 1024**3
        # "Free" from PyTorch's perspective only: memory held by other processes is not subtracted
        free_memory = total_memory - reserved_memory
        
        gpu_data = {
            'id': i,
            'name': device_props.name,
            'total_memory_gb': round(total_memory, 2),
            'allocated_memory_gb': round(allocated_memory, 2),
            'reserved_memory_gb': round(reserved_memory, 2),
            'free_memory_gb': round(free_memory, 2),
            'compute_capability': f"{device_props.major}.{device_props.minor}",
            'multi_processor_count': device_props.multi_processor_count
        }
        gpu_info['gpus'].append(gpu_data)
        
        marker = "🎯" if i == gpu_info['current_device'] else "  "
        print(f"\n{marker} GPU {i}: {device_props.name}")
        print(f"   • Compute Capability: {device_props.major}.{device_props.minor}")
        print(f"   • Multi-processors: {device_props.multi_processor_count}")
        print(f"   • Total Memory: {total_memory:.2f} GB")
        print(f"   • Allocated: {allocated_memory:.2f} GB")
        print(f"   • Reserved: {reserved_memory:.2f} GB")
        print(f"   • Free: {free_memory:.2f} GB ({(free_memory/total_memory)*100:.1f}%)")
    
    print("="*60 + "\n")
    return gpu_info

gpu_info = display_gpu_info()
print("✅ Utility functions defined!")

## Configuration Classes

### BasicConfig Class

Defines basic processing parameters with validation and helpful metadata.

**Key Parameters:**
- `output_dir`: Directory where results will be saved
- `batch_size`: Number of samples processed per batch (default: 100K for efficiency)
- `num_folds`: Number of folds (partitions) the data is split into during subset selection
- `combine_files`: Whether to merge multiple input files into one dataset
- `epsilon`: Optimization parameter for submodular facility location
  - Default: 160.0 (optimized for datasets >100K samples)
  - For smaller datasets: use values starting from 0.1

In [None]:
@dataclass
class BasicConfig:
    """Basic configuration parameters"""
    output_dir: str = "../../assets/subset-selection/outputs" # change this to your desired output directory
    batch_size: int = 100000
    num_folds: int = 50
    combine_files: bool = False
    epsilon: float = field(
        default=160.0,
        metadata={
            "advanced": True,
            "help": "Epsilon parameter for the LazierThanLazyGreedy optimizer in facility location maximization. "
            "Default of 160.0 is optimized for datasets >100k samples. "
            "For smaller datasets, consider using much smaller values (starting from 0.1).",
        },
    )

    def __post_init__(self):
        """Validate configuration after initialization"""
        if not 0 < self.epsilon <= 160:
            raise ValueError("epsilon must be in the range (0, 160]")

    def validate_epsilon_for_dataset_size(self, dataset_size: int) -> None:
        """
        Validate epsilon parameter based on dataset size and provide appropriate warnings.

        Args:
            dataset_size (int): Size of the dataset being processed
        """
        if dataset_size < 100000:
            logger.warning(
                "Subset selection is recommended only for datasets with at least 100k samples. "
                f"Your dataset has {dataset_size:,} samples."
            )
            if self.epsilon > 1.0:
                logger.warning(
                    f"Current epsilon value ({self.epsilon}) may be too high for a dataset of this size. "
                    "For smaller datasets, consider using much smaller values (starting from 0.1) "
                    "to ensure proper subset selection."
                )

### EncoderConfig Class

Configures the embedding encoder. Separates encoder settings from other parameters for modularity.

**Key Parameters:**
- `instruction`: Prompt prefix for the encoder to guide embedding generation
- `encoder_type`: Type of encoder to use (e.g., "arctic")
- `encoder_model`: Specific model identifier (e.g., "Snowflake/snowflake-arctic-embed-l-v2.0")
- `testing_mode`: If True, enables development features (CPU fallback, model auto-download)

In [None]:
@dataclass
class EncoderConfig:
    """Encoder-specific configuration parameters."""
    instruction: str = field(
        default="Generate embeddings that capture the core meaning of user-assistant conversations, ensuring the embeddings can be clustered based on semantic similarity for subset selection.",
        metadata={"advanced": True},
    )
    encoder_type: str = field(default="arctic", metadata={"advanced": True})
    encoder_model: str = field(
        default="Snowflake/snowflake-arctic-embed-l-v2.0", metadata={"advanced": True}
    )
    testing_mode: bool = False

### TemplateConfig Class

Manages text formatting templates to enable flexible formatting for different data structures.

**Key Parameters:**
- `template_name`: Active template to use (e.g., "conversation")
- `templates`: Dictionary of available templates with Jinja2 syntax
  - `default`: Simple text passthrough
  - `conversation`: Formats multi-turn dialogues (user/assistant)
  - `qa`: Question-answer format

**Usage**: Templates convert structured data (dicts/lists) into plain text for embedding generation.
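To see what the encoder actually receives, the sketch below renders the default `conversation` template string with `jinja2` directly, the same way `DataProcessor` compiles it via `Environment(loader=BaseLoader()).from_string(...)`. The sample conversation is made up for illustration. Note that the loop filter drops `system` messages and each turn ends with a newline.

```python
from jinja2 import BaseLoader, Environment

# Same "conversation" template string as the TemplateConfig default
CONVERSATION_TEMPLATE = (
    "{% for msg in messages if msg.role != 'system' %}"
    "{{ msg.role }}: {{ msg.content }}\n"
    "{% endfor %}"
)

demo_env = Environment(loader=BaseLoader())
demo_template = demo_env.from_string(CONVERSATION_TEMPLATE)

# Hypothetical example record; Jinja's `msg.role` falls back to dict lookup
example = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing well, thank you!"},
    ]
}

rendered = demo_template.render(**example)
print(rendered)
```

Because `render(**example)` unpacks the record, the template variables must match the dataset's column names (`messages` here; `text` for `default`; `question`/`answer` for `qa`).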

In [None]:
@dataclass
class TemplateConfig:
    """Template-related configuration parameters."""
    template_name: str = field(default="conversation", metadata={"advanced": True})
    templates: Dict[str, str] = field(
        default_factory=lambda: {
            "default": "{{ text }}",
            "conversation": "{% for msg in messages if msg.role != 'system' %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}",
            "qa": "Question: {{ question }}\nAnswer: {{ answer }}",
        },
        metadata={"advanced": True},
    )

### SystemConfig Class

Manages system-level configuration for handling resources and error recovery.

**Key Parameters:**
- `num_gpus`: Auto-detects available GPUs (set automatically in `__post_init__`)
- `seed`: Random seed for reproducibility (default: 42)
- `max_retries`: Number of retry attempts for failed operations (default: 3)
- `retry_delay`: Seconds to wait between retries (default: 30)
- `testing_mode`: Enables testing features (CPU fallback, reduced validation)
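Because `num_gpus` is declared with `field(init=False)`, it cannot be passed to the constructor; `__post_init__` fills it in. One consequence: `ProcessingConfig` (defined below) uses `default_factory=SystemConfig`, so omitting `system=` runs GPU detection with `testing_mode=False` and raises `RuntimeError` on a CPU-only machine, even when `EncoderConfig(testing_mode=True)` is passed. The torch-free sketch below reproduces that behavior; `detect_gpus` and its `available` argument are hypothetical stand-ins for `get_default_num_gpus` and `torch.cuda.device_count()`.

```python
from dataclasses import dataclass, field


def detect_gpus(testing_mode: bool, available: int) -> int:
    """Hypothetical stand-in for get_default_num_gpus; `available` simulates the CUDA device count."""
    if available == 0:
        if testing_mode:
            return 1  # CPU fallback, mirroring the notebook's testing-mode behavior
        raise RuntimeError("No CUDA devices detected.")
    return available


@dataclass
class SystemConfigSketch:
    num_gpus: int = field(init=False)  # excluded from __init__, set in __post_init__
    testing_mode: bool = False
    simulated_gpus: int = 0  # hypothetical knob for this demo only

    def __post_init__(self):
        self.num_gpus = detect_gpus(self.testing_mode, self.simulated_gpus)


@dataclass
class ProcessingConfigSketch:
    # default_factory calls SystemConfigSketch() with testing_mode=False
    system: SystemConfigSketch = field(default_factory=SystemConfigSketch)


print(SystemConfigSketch(testing_mode=True).num_gpus)
print(SystemConfigSketch(simulated_gpus=4).num_gpus)
try:
    ProcessingConfigSketch()
except RuntimeError as e:
    print(f"Implicit SystemConfig failed: {e}")
```

Passing `system=SystemConfig(testing_mode=True)` explicitly, as the configuration cell later in this notebook does, avoids the error on CPU-only machines.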

In [None]:
@dataclass
class SystemConfig:
    """System-related configuration parameters."""
    num_gpus: int = field(init=False)
    seed: int = field(default=42, metadata={"advanced": True})
    max_retries: int = field(default=3, metadata={"advanced": True})
    retry_delay: int = field(default=30, metadata={"advanced": True})
    testing_mode: bool = field(default=False, metadata={"advanced": True})

    def __post_init__(self):
        """Initialize num_gpus after other fields are set."""
        self.num_gpus = get_default_num_gpus(testing_mode=self.testing_mode)

### ProcessingConfig Class

Main configuration class that combines all other configurations. Provides a single point of configuration with comprehensive validation.

**Required Parameters:**
- `input_files`: List of input file paths to process (JSONL format)
- `subset_sizes`: List of target subset sizes
  - Use **floats** in (0, 1] for fractions of the dataset: `[0.1, 0.05]` = 10% and 5%
  - Use **integers** for absolute counts: `[1000, 500]` = 1000 and 500 samples (`1` means one sample, whereas `1.0` means the whole dataset)

**Configuration Groups:**
- `basic`: Basic processing parameters (output dir, batch size, epsilon)
- `encoder`: Encoder-specific parameters (model, instruction, testing mode)
- `template`: Template-related parameters (template name, Jinja2 templates)
- `system`: System-related parameters (GPUs, seed, retries)

**Validation**: `__post_init__` checks that `subset_sizes` is a list of positive integers or floats in (0, 1]; `epsilon` is validated separately by `BasicConfig`.
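The sketch below mirrors the conversion rule from `DataProcessor.calculate_subset_size` as a standalone function (`resolve_subset_size` is a hypothetical name) so the arithmetic can be checked by hand: fractions are multiplied by the dataset size and truncated with `int()`, with a floor of one sample, while integers are capped at the dataset size.

```python
from typing import Union


def resolve_subset_size(total_samples: int, size_spec: Union[int, float]) -> int:
    """Standalone mirror of the calculate_subset_size rule (sketch)."""
    if isinstance(size_spec, float):
        if not 0 < size_spec <= 1:
            raise ValueError("Fractional sizes must be in the range (0, 1]")
        # Fraction of the dataset, truncated toward zero, never below one sample
        return max(1, int(size_spec * total_samples))
    # Absolute count, capped at the dataset size
    return min(size_spec, total_samples)


total = 123_457
for spec in [0.1, 0.05, 1.0, 1, 1000, 500_000]:
    print(f"{spec!r:>8} -> {resolve_subset_size(total, spec):,}")
```

For 123,457 samples, `0.1` truncates 12,345.7 down to 12,345, and `1.0` and `1` resolve to very different subsets (the whole dataset versus a single sample).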

In [None]:
@dataclass
class ProcessingConfig:
    """
    Configuration for subset selection with basic and advanced parameters.
    
    """
    # Required parameters
    input_files: List[str]
    subset_sizes: List[Union[int, float]]

    # Configuration groups
    basic: BasicConfig = field(default_factory=BasicConfig)
    encoder: EncoderConfig = field(default_factory=EncoderConfig)
    template: TemplateConfig = field(default_factory=TemplateConfig)
    system: SystemConfig = field(default_factory=SystemConfig)

    def __post_init__(self):
        """Validate configuration after initialization."""
        if not isinstance(self.subset_sizes, list):
            raise ValueError("subset_sizes must be a list")

        for size in self.subset_sizes:
            # bool is a subclass of int, so reject it explicitly
            if isinstance(size, bool) or not isinstance(size, (int, float)):
                raise ValueError("subset_sizes must contain only integers or floats")
            # Floats are fractions of the dataset, matching DataProcessor.calculate_subset_size
            if isinstance(size, float) and not 0 < size <= 1:
                raise ValueError(
                    "Fractional values in subset_sizes must be in the range (0, 1]"
                )
            if isinstance(size, int) and size <= 0:
                raise ValueError("Absolute values in subset_sizes must be positive")

## Data Processor Class

Main processing class that handles the complete data preparation pipeline with proper error handling.

**Core Responsibilities:**
- **Data Loading**: Loads and optionally combines multiple datasets from files
- **Text Formatting**: Applies Jinja2 templates to convert structured data to text
- **Subset Calculations**: Converts percentage/absolute subset sizes to actual sample counts
- **Configuration Management**: Maintains all configuration and processing state
- **Device Management**: Sets up GPU/CPU devices and random seeds
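The naming helpers determine how outputs are labeled, and two details are easy to misread. `get_subset_name` writes the *fraction* after `percent_`, so `percent_0.1` is a 10% subset, and the `:g` format keeps only six significant digits. `get_dataset_name` replaces every character other than letters, digits, `_` and `-` with `_`. The standalone copies below (`dataset_name` and `subset_name` are hypothetical names) reproduce that logic so the results can be checked.

```python
import os
import re


def dataset_name(input_file: str) -> str:
    """Standalone copy of the get_dataset_name logic (sketch)."""
    base = os.path.splitext(os.path.basename(input_file))[0]
    return re.sub(r"[^\w\-_]", "_", base)


def subset_name(size_spec, actual_size: int) -> str:
    """Standalone copy of the get_subset_name logic (sketch)."""
    if isinstance(size_spec, float):
        return f"percent_{size_spec:g}"  # fraction, not a percentage
    return f"samples_{actual_size}"


print(dataset_name("../../assets/subset-selection/combined_cut_50x.jsonl"))
print(dataset_name("data/my chats v2.0.jsonl"))
print(subset_name(0.1, 12345), subset_name(0.1234567, 15241), subset_name(1000, 1000))
```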

In [None]:
class DataProcessor:
    """
    Enhanced data processor with support for combined files and multiple selection methods.
    
    This class handles the complete pipeline:
    - Data loading and formatting
    - Embedding generation (to be implemented in Notebook 2)
    - Subset selection (to be implemented in Notebook 3)
    - Result export
    """

    def __init__(self, config: ProcessingConfig):
        """
        Initializes the DataProcessor with the given configuration.

        Args:
            config (ProcessingConfig): The processing configuration.
        """
        self.config = config
        self.env = Environment(loader=BaseLoader())
        self.templates = {
            k: self.env.from_string(v) for k, v in config.template.templates.items()
        }
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Set random seeds
        np.random.seed(config.system.seed)
        torch.manual_seed(config.system.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(config.system.seed)

    def format_text(self, example: Dict[str, Any], format_type: str) -> str:
        """
        Formats the text of an example using the specified template.

        Args:
            example (Dict[str, Any]): The data example to format.
            format_type (str): The key of the template to use.

        Returns:
            str: The formatted text.
        """
        template = self.templates.get(format_type)
        if template is None:
            raise ValueError(f"Unknown format type: {format_type}")
        return template.render(**example)

    def load_and_combine_datasets(self, input_files: List[str]):
        """
        Load and optionally combine multiple datasets.

        Args:
            input_files (List[str]): List of input file paths.

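        Returns:
            The concatenated dataset if ``combine_files`` is enabled, otherwise the
            single loaded dataset.

        Raises:
            ValueError: If several files are given while ``combine_files`` is disabled.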
        """
        datasets = []

        for input_file in input_files:
            file_extension = input_file.split(".")[-1]
            # JSONL files are read by the HuggingFace "json" builder
            if file_extension == "jsonl":
                file_extension = "json"
            dataset = load_dataset(
                file_extension, data_files=input_file, split="train", cache_dir=None
            )
            datasets.append(dataset)

        if self.config.basic.combine_files:
            logger.info("Combining datasets...")
            return concatenate_datasets(datasets)

        if len(datasets) > 1:
            raise ValueError(
                "Multiple datasets provided but combine_files is not enabled"
            )
        return datasets[0]

    def calculate_subset_size(self, total_samples: int, size_spec: Union[int, float]) -> int:
        """
        Calculate the actual subset size based on the specification.
        
        Args:
            total_samples (int): Total number of samples in the dataset.
            size_spec (Union[int, float]): Size specification (percentage if float, absolute if int).

        Returns:
            int: Actual number of samples to select.
        """
        if isinstance(size_spec, float):
            # Handle percentage (0.1 = 10%, 0.05 = 5%)
            if size_spec <= 0 or size_spec > 1:
                raise ValueError(
                    "Percentage values must be between 0(non-inclusive) and 1(inclusive)"
                )
            return max(1, int(size_spec * total_samples))
        # Treat as absolute number
        return min(size_spec, total_samples)

    def get_subset_name(self, size_spec: Union[int, float], actual_size: int) -> str:
        """
        Generate appropriate subset name based on selection method.

        Args:
            size_spec (Union[int, float]): Original size specification.
            actual_size (int): Actual number of samples selected.

        Returns:
            str: Descriptive name for the subset.
        """
        if isinstance(size_spec, float):
            # Use :g format to automatically remove trailing zeros
            # 0.1 -> "0.1", 0.05 -> "0.05", 0.10 -> "0.1"
            return f"percent_{size_spec:g}"
        return f"samples_{actual_size}"

    def get_dataset_name(self, input_file: str) -> str:
        """
        Get a clean dataset name from the input file path.

        Args:
            input_file (str): Input file path

        Returns:
            str: Clean dataset name
        """
        base_name = os.path.splitext(os.path.basename(input_file))[0]
        clean_name = re.sub(r"[^\w\-_]", "_", base_name)
        return clean_name

print("✅ DataProcessor class defined successfully with all utility functions!")

## Configuration Setup

Create a complete configuration for the subset selection pipeline.

**Key Configuration Decisions:**

1. **Input Files**: Update `input_files` path to your JSONL dataset
   - Expected format: One JSON object per line with `messages` field
   
2. **Subset Sizes**: `[0.1, 0.05]` creates two subsets (10% and 5% of original data)
   - Use floats (0-1) for percentages of the dataset
   - Use integers for absolute sample counts (e.g., `[1000, 500]`)
   
3. **Batch Size**: 100,000 samples per batch balances memory usage and processing speed
   - Reduce if encountering OOM errors
   
4. **Epsilon**: 160.0 optimized for large datasets (>100K samples)
   - Controls the trade-off between quality and speed in facility location
   - For smaller datasets (<100K), use values starting from 0.1
   
5. **Testing Mode**: Enabled in both `EncoderConfig` and `SystemConfig` for development (allows CPU, auto-downloads models)
   - Disable both for production use with pre-downloaded models and GPUs
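The expected input format is JSON Lines: one JSON object per line, each with a `messages` list whose entries carry `role` and `content` (the fields the `conversation` template reads). The stdlib-only sketch below writes two made-up records in that shape and runs a simple pre-flight check; `write_jsonl` and `check_jsonl` are hypothetical helpers, not part of the pipeline.

```python
import json
import os
import tempfile

# Two hypothetical records in the shape the "conversation" template expects
records = [
    {"messages": [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing well, thank you!"},
    ]},
    {"messages": [
        {"role": "system", "content": "Be concise."},
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]},
]


def write_jsonl(path, rows):
    """Write one JSON object per line (JSONL)."""
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def check_jsonl(path):
    """Count valid records; raise ValueError on the first line without a usable `messages` list."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            obj = json.loads(line)
            messages = obj.get("messages")
            if not isinstance(messages, list) or not all(
                isinstance(m, dict) and "role" in m and "content" in m for m in messages
            ):
                raise ValueError(f"line {lineno}: missing or malformed 'messages'")
            count += 1
    return count


with tempfile.TemporaryDirectory() as tmp:
    sample_path = os.path.join(tmp, "sample.jsonl")
    write_jsonl(sample_path, records)
    n_valid = check_jsonl(sample_path)

print(f"{n_valid} valid records")
```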

In [None]:
config = ProcessingConfig(
    input_files=["../../assets/subset-selection/combined_cut_50x.jsonl"], # Update with your data path
    subset_sizes=[0.1, 0.05],  # 10% and 5% subsets
    basic=BasicConfig(
        output_dir="../../assets/subset-selection/outputs",  # Change to your desired directory (used by Notebooks 2 & 3)
        batch_size=100000,
        num_folds=25,
        epsilon=160.0,
        combine_files=False
    ),
    encoder=EncoderConfig(
        encoder_type="arctic",
        encoder_model="Snowflake/snowflake-arctic-embed-l-v2.0",
        testing_mode=True  # Enable for notebook development
    ),
    template=TemplateConfig(
        template_name="conversation",
        templates={
            "default": "{{ text }}",
            "conversation": "{% for msg in messages if msg.role != 'system' %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}",
            "qa": "Question: {{ question }}\nAnswer: {{ answer }}",
        }
    ),
    system=SystemConfig(
        seed=42,
        testing_mode=True,
        max_retries=3,
        retry_delay=30
    )
)

print("===== Configuration created successfully! =====")
print(f"Input files: {config.input_files}")
print(f"Subset sizes: {config.subset_sizes}")
print(f"Output directory: {config.basic.output_dir}")
print(f"Encoder type: {config.encoder.encoder_type}")
print(f"Template name: {config.template.template_name}")
print(f"Number of GPUs: {config.system.num_gpus}")
print(f"Testing mode: {config.system.testing_mode}")

## Data Loading and Validation
- Checks that the first input file exists
- Loads the dataset using HuggingFace `datasets`
- Shows dataset statistics

In [None]:
print("📁 Data Loading")
print("=" * 50)

# Default to None so the inspection cell can detect a failed load
dataset = None

# Check if the first input file exists
data_file = config.input_files[0]
if not os.path.exists(data_file):
    print(f"❌ Data file not found: {data_file}")
    print("Please update `input_files` in the configuration above")
    print("Expected JSONL format (one JSON object per line):")
    print(
        '{"messages": [{"role": "user", "content": "Hello, how are you?"}, '
        '{"role": "assistant", "content": "I\'m doing well, thank you!"}]}'
    )
else:
    print(f"✅ Found data file: {data_file}")
    
    # Load the dataset
    try:
        dataset = load_dataset("json", data_files=data_file, split="train", cache_dir=None)
        print(f"✅ Dataset loaded successfully!")
        print(f"📊 Dataset size: {len(dataset):,} samples")
        
        # Show file size
        file_size = os.path.getsize(data_file) / 1024**2  # MB
        print(f"📁 File size: {file_size:.1f} MB")
        
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        dataset = None

## Data Inspection
Analyzes the structure and quality of the loaded data:
- Displays sample data
- Analyzes column structure
- Counts messages and roles (over the first 1,000 samples)
- Validates epsilon for the dataset size

In [None]:
if dataset is not None:
    print("🔍 Data Inspection")
    print("=" * 50)
    
    # Show sample data
    print("📋 Sample data:")
    for i in range(min(3, len(dataset))):
        print(f"\nSample {i+1}:")
        sample = dataset[i]
        for key, value in sample.items():
            if isinstance(value, str) and len(value) > 100:
                print(f"  {key}: {value[:100]}...")
            else:
                print(f"  {key}: {value}")
    
    # Analyze data structure
    print(f"\n📊 Data structure analysis:")
    print(f"   Number of samples: {len(dataset):,}")
    
    # Check column names
    if hasattr(dataset, 'column_names'):
        print(f"   Column names: {dataset.column_names}")
    
    # Analyze message structure if it exists
    if 'messages' in dataset.column_names:
        message_lengths = []
        role_counts = {'user': 0, 'assistant': 0, 'system': 0}

        # Inspect at most the first 1,000 samples to keep this cell fast
        for sample in dataset.select(range(min(1000, len(dataset)))):
            messages = sample['messages'] or []
            message_lengths.append(len(messages))
            for msg in messages:
                role = msg.get('role')
                if role is not None:
                    # .get() tolerates roles outside the three defaults (e.g. "tool")
                    role_counts[role] = role_counts.get(role, 0) + 1

        if message_lengths:
            print(f"   Average messages per conversation: {np.mean(message_lengths):.1f}")
        print(f"   Role distribution: {role_counts}")
    
    # Validate epsilon for dataset size
    config.basic.validate_epsilon_for_dataset_size(len(dataset))
    
else:
    print("❌ No dataset loaded. Please fix the data loading issue above.")

### 🎯 Next Steps

**Proceed to Notebook 2: Embedding Generation**

1. Open `embedding_generation.ipynb`
2. Run all cells sequentially
3. The notebook will:
   - Import all objects from this notebook using `%run`
   - Add `generate_embeddings()` method to `DataProcessor`
   - Generate embeddings using Arctic encoder on GPU(s)
   - Save embeddings to `{output_dir}/embeddings/embeddings.h5`