# Classifier Module

> Core classification functions for data and code availability statements.

This module provides:
- `classify_statement()`: Classify a single availability statement
- `classify_publication()`: Classify both data and code for a publication

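The classifier uses few-shot prompting: for each statement it selects the `k` most semantically similar coded training examples (k-nearest neighbours over sentence embeddings) and includes them in the LLM prompt.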

In [None]:
#| default_exp classifier

In [None]:
#| export
from __future__ import annotations
from typing import Optional, List, Tuple
from datetime import datetime
import logging

from openness_classifier.core import (
    OpennessCategory,
    ClassificationType,
    Classification,
    LLMConfiguration,
    LLMProvider,
    ClassificationLogger,
    ClassificationError,
    LLMError,
)
from openness_classifier.config import ClassifierConfig, load_config
from openness_classifier.data import (
    TrainingExample,
    Publication,
    EmbeddingModel,
    load_training_data,
    compute_embeddings,
)
from openness_classifier.prompts import (
    select_knn_examples,
    build_few_shot_prompt,
    parse_classification_response,
    SYSTEM_PROMPT,
)

## Classifier Class

Main classifier that manages training data, embeddings, and LLM calls.

In [None]:
#| export
class OpennessClassifier:
    """Few-shot LLM classifier for data and code availability statements.

    An instance holds the coded training examples for both statement kinds,
    the embedding model used to pick nearest-neighbour examples, an LLM
    provider built from ``config.llm``, and an optional classification logger.

    Example:
        >>> classifier = OpennessClassifier.from_config(load_config())
        >>> result = classifier.classify_statement(
        ...     "Data available at https://zenodo.org/record/12345",
        ...     ClassificationType.DATA
        ... )
        >>> print(result.category)  # OpennessCategory.OPEN
    """

    def __init__(
        self,
        config: ClassifierConfig,
        data_examples: List[TrainingExample],
        code_examples: List[TrainingExample],
        embedding_model: EmbeddingModel,
        logger: Optional[ClassificationLogger] = None,
    ):
        """Store the pre-loaded examples and build the LLM provider.

        Args:
            config: Classifier configuration; ``config.llm`` configures the provider.
            data_examples: Training examples used for DATA statements.
            code_examples: Training examples used for CODE statements.
            embedding_model: Model used for kNN example selection.
            logger: Optional logger; classifications are only logged when
                both a logger and a publication ID are available.
        """
        self.config = config
        self.data_examples = data_examples
        self.code_examples = code_examples
        self.embedding_model = embedding_model
        self.llm_provider = LLMProvider(config.llm)
        self.logger = logger

    @classmethod
    def from_config(cls, config: ClassifierConfig) -> 'OpennessClassifier':
        """Build a classifier from configuration alone.

        Loads the training data from ``config.training_data_path``, embeds
        both example sets with ``config.embedding_model``, and opens a
        per-day JSONL classification log under ``config.log_dir``.
        """
        data_examples, code_examples = load_training_data(config.training_data_path)

        embedder = EmbeddingModel(config.embedding_model)
        # Embed DATA examples first, then CODE examples.
        for example_set in (data_examples, code_examples):
            compute_embeddings(example_set, embedder)

        day_stamp = datetime.now().strftime('%Y%m%d')
        log_file = config.log_dir / f"classifications_{day_stamp}.jsonl"

        return cls(
            config=config,
            data_examples=data_examples,
            code_examples=code_examples,
            embedding_model=embedder,
            logger=ClassificationLogger(log_file),
        )

    def classify_statement(
        self,
        statement: str,
        statement_type: ClassificationType,
        return_reasoning: bool = True,
        publication_id: Optional[str] = None,
    ) -> Classification:
        """Classify one availability statement with a kNN few-shot prompt.

        Args:
            statement: The availability statement text.
            statement_type: ``ClassificationType.DATA`` or ``ClassificationType.CODE``.
            return_reasoning: Ask for and keep chain-of-thought reasoning.
            publication_id: Optional ID; required for the result to be logged.

        Returns:
            Classification carrying the category, confidence, optional
            reasoning, LLM configuration, and IDs of the few-shot examples used.
        """
        if statement_type == ClassificationType.DATA:
            pool = self.data_examples
        else:
            pool = self.code_examples

        neighbours = select_knn_examples(
            statement=statement,
            training_examples=pool,
            embedding_model=self.embedding_model,
            k=self.config.few_shot_k,
        )
        user_prompt = build_few_shot_prompt(
            statement=statement,
            statement_type=statement_type,
            examples=neighbours,
            include_reasoning=return_reasoning,
        )

        # The system prompt is sent as a prefix of one combined prompt.
        raw_reply = self.llm_provider.complete(f"{SYSTEM_PROMPT}\n\n{user_prompt}")
        category, confidence, reasoning = parse_classification_response(raw_reply)

        kept_reasoning = reasoning if return_reasoning else None
        neighbour_ids = [example.id for example in neighbours]
        result = Classification(
            category=category,
            statement_type=statement_type,
            confidence_score=confidence,
            reasoning=kept_reasoning,
            model_config=self.config.llm,
            few_shot_example_ids=neighbour_ids,
        )

        if not (self.logger and publication_id):
            return result

        self.logger.log_classification(
            publication_id=publication_id,
            classification=result,
            statement_text=statement,
        )
        return result

    def classify_publication(
        self,
        publication: Publication,
        return_reasoning: bool = True,
    ) -> Tuple[Optional[Classification], Optional[Classification]]:
        """Classify the data and code statements of a publication.

        Args:
            publication: Publication holding optional data/code statements.
            return_reasoning: Keep reasoning in each result.

        Returns:
            ``(data_classification, code_classification)``; an entry is None
            when the publication reports no statement of that kind.
        """
        data_result = (
            self.classify_statement(
                statement=publication.data_statement,
                statement_type=ClassificationType.DATA,
                return_reasoning=return_reasoning,
                publication_id=publication.id,
            )
            if publication.has_data_statement()
            else None
        )
        code_result = (
            self.classify_statement(
                statement=publication.code_statement,
                statement_type=ClassificationType.CODE,
                return_reasoning=return_reasoning,
                publication_id=publication.id,
            )
            if publication.has_code_statement()
            else None
        )
        return data_result, code_result

## Convenience Functions

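Module-level functions for simpler usage: they lazily create and reuse a shared `OpennessClassifier`, so callers do not have to manage an instance themselves.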

In [None]:
#| export
# Shared classifier used by the module-level convenience functions; built lazily.
_default_classifier: Optional[OpennessClassifier] = None


def get_classifier(config: Optional[ClassifierConfig] = None) -> OpennessClassifier:
    """Get or create the shared default classifier instance.

    Building a classifier loads the training data and embeds every training
    example (``OpennessClassifier.from_config``), so the instance is cached
    module-wide and reused.

    Args:
        config: Optional configuration. If omitted, the cached classifier is
            returned, or one is built from ``load_config()`` if none exists.
            If given and it is not the same as (or equal to) the cached
            classifier's config, a new classifier is built and replaces the
            cached one.

    Returns:
        OpennessClassifier instance
    """
    global _default_classifier

    cached = _default_classifier
    # Rebuild only when nothing is cached or a *different* config is requested;
    # passing the same config again must not re-read training data and
    # recompute all embeddings.
    if cached is not None and (
        config is None or config is cached.config or config == cached.config
    ):
        return cached

    if config is None:
        config = load_config()
    _default_classifier = OpennessClassifier.from_config(config)
    return _default_classifier


def classify_statement(
    statement: str,
    statement_type: ClassificationType | str,
    config: Optional[ClassifierConfig] = None,
    return_reasoning: bool = True,
    publication_id: Optional[str] = None,
) -> Classification:
    """Classify a single availability statement.

    Convenience function that manages classifier lifecycle via
    ``get_classifier``.

    Args:
        statement: The availability statement text
        statement_type: "data" or "code" (case-insensitive, surrounding
            whitespace ignored), or a ClassificationType
        config: Optional configuration
        return_reasoning: Include reasoning in result
        publication_id: Optional ID; when given, the classification is
            passed to the classifier's logger (if it has one)

    Returns:
        Classification result

    Raises:
        ValueError: If ``statement_type`` is a string that does not name a
            ClassificationType.

    Example:
        >>> result = classify_statement(
        ...     "Data available at https://zenodo.org/record/12345",
        ...     "data"
        ... )
        >>> print(result.category.value)  # 'open'
    """
    # Convert string to enum before building the (expensive) classifier so
    # bad input fails fast.
    if isinstance(statement_type, str):
        normalized = statement_type.strip().lower()
        try:
            statement_type = ClassificationType(normalized)
        except ValueError as err:
            valid = ", ".join(repr(member.value) for member in ClassificationType)
            raise ValueError(
                f"Unknown statement_type {statement_type!r}; expected one of: {valid}"
            ) from err

    classifier = get_classifier(config)
    return classifier.classify_statement(
        statement=statement,
        statement_type=statement_type,
        return_reasoning=return_reasoning,
        publication_id=publication_id,
    )


def classify_publication(
    data_statement: Optional[str] = None,
    code_statement: Optional[str] = None,
    publication_id: str = "unknown",
    config: Optional[ClassifierConfig] = None,
    return_reasoning: bool = True,
) -> Tuple[Optional[Classification], Optional[Classification]]:
    """Classify data and code availability for a publication.

    Wraps the statements in a ``Publication`` and delegates to the shared
    classifier from ``get_classifier``.

    Args:
        data_statement: Data availability statement (optional)
        code_statement: Code availability statement (optional)
        publication_id: Identifier for logging
        config: Optional configuration
        return_reasoning: Include reasoning in each result

    Returns:
        Tuple of (data_classification, code_classification); an entry is
        None when the publication has no statement of that kind
    """
    pub = Publication(
        id=publication_id,
        data_statement=data_statement,
        code_statement=code_statement,
    )

    classifier = get_classifier(config)
    return classifier.classify_publication(pub, return_reasoning=return_reasoning)

## Low-Confidence Identification

In [None]:
#| export
def identify_low_confidence(
    classifications: List[Classification],
    threshold: float = 0.5
) -> List[Classification]:
    """Return the classifications whose confidence is strictly below ``threshold``.

    Useful for picking out statements that may need manual review. Input
    order is preserved.

    Args:
        classifications: Classification results to screen.
        threshold: Cut-off; scores below it count as low confidence
            (default: 0.5).

    Returns:
        New list holding only the low-confidence classifications.
    """
    def _below_threshold(item: Classification) -> bool:
        return item.confidence_score < threshold

    return list(filter(_below_threshold, classifications))


def suggest_training_examples(
    classifications: List[Tuple[str, Classification]],
    threshold: float = 0.5,
    max_suggestions: int = 10
) -> List[Tuple[str, Classification]]:
    """Suggest statements that would benefit from manual coding.

    Returns low-confidence classifications that should be manually
    reviewed and potentially added to training data, least confident first.

    Args:
        classifications: List of (statement_text, classification) tuples
        threshold: Confidence threshold; scores strictly below it qualify
        max_suggestions: Maximum suggestions to return (must be >= 0)

    Returns:
        List of (statement, classification) tuples needing review, sorted
        by ascending confidence; ties keep their input order

    Raises:
        ValueError: If ``max_suggestions`` is negative.
    """
    if max_suggestions < 0:
        # A negative slice bound would silently drop items from the end of
        # the list instead of capping its length, so reject it explicitly.
        raise ValueError(f"max_suggestions must be >= 0, got {max_suggestions}")

    low_conf = [
        (statement, clf) for statement, clf in classifications
        if clf.confidence_score < threshold
    ]

    # list.sort is stable: equal scores keep their input order.
    low_conf.sort(key=lambda pair: pair[1].confidence_score)

    return low_conf[:max_suggestions]

In [None]:
# Note: Full testing requires API keys and is done in examples/01_single_classification.ipynb
# Print a short usage hint, one message per line.
print(
    "Classifier module ready!",
    "To test, set up your .env file with API keys and run:",
    "  from openness_classifier import classify_statement",
    "  result = classify_statement('Data available at Zenodo', 'data')",
    sep="\n",
)

In [None]:
#| hide
# Run nbdev's export so the `#| export` cells in this notebook are written to
# the library module named by `#| default_exp` (here: `classifier`).
import nbdev; nbdev.nbdev_export()