# Core Module

> Base types, enumerations, and error classes for the openness classifier.

This module provides the foundational data structures used throughout the library:

- **OpennessCategory**: 4-category ordinal taxonomy for classification
- **ClassificationType**: Whether a data or a code availability statement is being classified
- **Classification**: Result of a classification with metadata
- **LLMProvider**: Abstraction for multi-provider LLM support
- **Error classes**: Custom exceptions for error handling

In [None]:
#| default_exp core

In [None]:
#| export
from __future__ import annotations

import hashlib
import json
import logging
import os
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional

## Enumerations

### OpennessCategory

The 4-category ordinal taxonomy for classifying data/code openness, based on the rubric from articles_reviewed.csv:

| Category | Description | Examples |
|----------|-------------|----------|
| open | Fully accessible, no restrictions | Zenodo, Figshare, public GitHub |
| mostly_open | Largely accessible with minor restrictions | Registration required, institutional access |
| mostly_closed | Largely restricted with limited access | Data use agreements, partial availability |
| closed | Not accessible | "Upon request", confidential, proprietary |

In [None]:
#| export
class OpennessCategory(str, Enum):
    """4-category ordinal taxonomy for data/code openness classification.

    Categories are ordered from most open to least open:
    open > mostly_open > mostly_closed > closed

    Classification Rules (per articles_reviewed.csv rubric):
    - OPEN: Public repository with no barriers (Zenodo, Figshare, GitHub public)
    - MOSTLY_OPEN: Public repository with registration, institutional access
    - MOSTLY_CLOSED: Data use agreements, partial availability, some restrictions
    - CLOSED: "Available upon request", confidential, not accessible

    Members are ``str`` subclasses, so each one also compares equal to its raw
    value (``OpennessCategory.OPEN == "open"``). All four ordering operators are
    defined explicitly on the openness scale; without that, ``>`` and ``>=``
    would silently fall back to ``str``'s alphabetical comparison.
    """
    OPEN = "open"
    MOSTLY_OPEN = "mostly_open"
    MOSTLY_CLOSED = "mostly_closed"
    CLOSED = "closed"

    @classmethod
    def from_string(cls, value: str) -> 'OpennessCategory':
        """Parse category from string, handling various formats.

        Matching ignores case and surrounding whitespace, and treats any run of
        spaces, hyphens, or underscores as a single separator.

        Handles mappings from articles_reviewed.csv:
        - 'Closed', 'closed' -> CLOSED
        - 'Partially Closed', 'mostly closed', 'mostly_closed' -> MOSTLY_CLOSED
        - 'Partially Open', 'mostly open', 'mostly_open' -> MOSTLY_OPEN
        - 'Open', 'open' -> OPEN

        Args:
            value: The label to parse.

        Returns:
            The matching category.

        Raises:
            ValueError: If ``value`` is not a string (e.g. None or NaN from an
                empty CSV cell) or does not name a known category.
        """
        if not isinstance(value, str):
            # Non-string cells would otherwise fail with an opaque AttributeError
            raise ValueError(f"Unknown openness category: {value}")

        # Turn every separator into whitespace, then let split() collapse runs
        # and drop leading/trailing separators in one pass.
        words = value.lower().replace('-', ' ').replace('_', ' ').split()
        normalized = '_'.join(words)

        # Handle articles_reviewed.csv format
        mapping = {
            'closed': cls.CLOSED,
            'partially_closed': cls.MOSTLY_CLOSED,
            'mostly_closed': cls.MOSTLY_CLOSED,
            'partially_open': cls.MOSTLY_OPEN,
            'mostly_open': cls.MOSTLY_OPEN,
            'open': cls.OPEN,
        }

        if normalized in mapping:
            return mapping[normalized]

        raise ValueError(f"Unknown openness category: {value}")

    @classmethod
    def _rank_of(cls, candidate: object):
        """Return the position of *candidate* on the openness scale (0 = closed).

        Returns None when *candidate* is not a string, so the comparison
        operators can answer ``NotImplemented``. Plain strings are accepted
        because members compare equal to their raw value; a string that is not
        a category value raises ValueError.
        """
        if not isinstance(candidate, str):
            return None
        scale = (cls.CLOSED, cls.MOSTLY_CLOSED, cls.MOSTLY_OPEN, cls.OPEN)
        return scale.index(candidate)

    def __lt__(self, other: 'OpennessCategory') -> bool:
        """Compare categories by openness level (closed < mostly_closed < mostly_open < open)."""
        other_rank = self._rank_of(other)
        if other_rank is None:
            return NotImplemented
        return self._rank_of(self) < other_rank

    def __le__(self, other: 'OpennessCategory') -> bool:
        """Return True if this category is at most as open as ``other``."""
        other_rank = self._rank_of(other)
        if other_rank is None:
            return NotImplemented
        return self._rank_of(self) <= other_rank

    def __gt__(self, other: 'OpennessCategory') -> bool:
        """Return True if this category is strictly more open than ``other``."""
        other_rank = self._rank_of(other)
        if other_rank is None:
            return NotImplemented
        return self._rank_of(self) > other_rank

    def __ge__(self, other: 'OpennessCategory') -> bool:
        """Return True if this category is at least as open as ``other``."""
        other_rank = self._rank_of(other)
        if other_rank is None:
            return NotImplemented
        return self._rank_of(self) >= other_rank

In [None]:
# Smoke test for OpennessCategory: label parsing, then ordinal comparison
for raw_label, expected_category in (
    ("Closed", OpennessCategory.CLOSED),
    ("Partially Closed", OpennessCategory.MOSTLY_CLOSED),
    ("open", OpennessCategory.OPEN),
):
    assert OpennessCategory.from_string(raw_label) == expected_category
assert OpennessCategory.CLOSED < OpennessCategory.OPEN
print("OpennessCategory tests passed!")

In [None]:
#| export
class ClassificationType(str, Enum):
    """Kind of availability statement a classification refers to.

    Members are ``str`` subclasses, so they compare equal to their raw value.
    """
    DATA = 'data'
    CODE = 'code'

In [None]:
#| export
class LLMProviderType(str, Enum):
    """LLM backends the library can be configured to use.

    Members are ``str`` subclasses, so they compare equal to their raw value.
    """
    CLAUDE = 'claude'
    OPENAI = 'openai'
    OLLAMA = 'ollama'

In [None]:
#| export
class BatchStatus(str, Enum):
    """State of a batch processing job.

    Members are ``str`` subclasses, so they compare equal to their raw value.
    """
    PENDING = 'pending'
    RUNNING = 'running'
    COMPLETED = 'completed'
    FAILED = 'failed'

## Error Classes

Custom exceptions for better error handling and debugging.

In [None]:
#| export
class ClassificationError(Exception):
    """Base exception for classification errors.

    Every other exception in this hierarchy derives from it, so catching
    ``ClassificationError`` handles all of them.
    """

class LLMError(ClassificationError):
    """Error from LLM provider (API failure, rate limit, etc.).

    Attributes:
        provider: Name of the provider the error came from, or None if not given.
        retryable: Whether the raiser marked the failure as worth retrying.
    """
    def __init__(self, message: str, provider: Optional[str] = None, retryable: bool = False):
        super().__init__(message)
        self.provider = provider
        self.retryable = retryable

class ConfigurationError(ClassificationError):
    """Error in configuration (missing API key, invalid settings)."""

class DataError(ClassificationError):
    """Error in data loading or processing."""

class ValidationError(ClassificationError):
    """Error in validation (invalid category, missing ground truth)."""

## Data Classes

### LLMConfiguration

Tracks language model configuration for reproducibility (FAIR principles).

In [None]:
#| export
@dataclass
class LLMConfiguration:
    """Configuration for LLM provider, tracked for reproducibility.

    Attributes:
        provider: LLM provider type (claude, openai, ollama)
        model_name: Model identifier (e.g., 'claude-3-5-sonnet-20241022')
        temperature: Sampling temperature (default: 0.1 for consistency)
        max_tokens: Maximum response tokens (default: 500)
        top_p: Nucleus sampling parameter (default: 0.95)
        api_endpoint: Optional custom API endpoint (for Ollama)
        api_key_hash: Truncated SHA-256 digest of the API key for the audit trail,
            as produced by ``hash_api_key`` (never store the key itself)
        configuration_timestamp: When the configuration was created, as a naive
            datetime expressed in UTC
    """
    provider: LLMProviderType
    model_name: str
    temperature: float = 0.1
    max_tokens: int = 500
    top_p: float = 0.95
    api_endpoint: Optional[str] = None
    api_key_hash: Optional[str] = None
    # datetime.utcnow() is deprecated since Python 3.12. Taking the aware "now"
    # and dropping tzinfo yields the same naive UTC value, so the serialized
    # format stays unchanged.
    configuration_timestamp: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization.

        The provider is stored by its string value and the timestamp in ISO 8601
        format, so the result can be passed straight to ``json.dumps``.
        """
        return {
            'provider': self.provider.value,
            'model_name': self.model_name,
            'temperature': self.temperature,
            'max_tokens': self.max_tokens,
            'top_p': self.top_p,
            'api_endpoint': self.api_endpoint,
            'api_key_hash': self.api_key_hash,
            'configuration_timestamp': self.configuration_timestamp.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'LLMConfiguration':
        """Create from dictionary.

        Args:
            data: Mapping in the shape produced by ``to_dict``. ``provider`` and
                ``model_name`` are required; every other key is optional.

        Returns:
            A new configuration. A missing or null ``configuration_timestamp``
            is replaced by the current UTC time.

        Raises:
            KeyError: If ``provider`` or ``model_name`` is missing.
            ValueError: If the provider or the timestamp string is not valid.
        """
        kwargs: Dict[str, Any] = {
            'provider': LLMProviderType(data['provider']),
            'model_name': data['model_name'],
            'temperature': data.get('temperature', 0.1),
            'max_tokens': data.get('max_tokens', 500),
            'top_p': data.get('top_p', 0.95),
            'api_endpoint': data.get('api_endpoint'),
            'api_key_hash': data.get('api_key_hash'),
        }
        raw_timestamp = data.get('configuration_timestamp')
        if raw_timestamp is not None:
            kwargs['configuration_timestamp'] = datetime.fromisoformat(raw_timestamp)
        # Otherwise the field's default_factory supplies the current UTC time
        return cls(**kwargs)

    def to_json(self) -> str:
        """Serialize to JSON string."""
        return json.dumps(self.to_dict(), indent=2)

    @staticmethod
    def hash_api_key(api_key: str) -> str:
        """Return a short fingerprint of an API key for the audit trail.

        The fingerprint is the first 16 hexadecimal characters of the key's
        SHA-256 digest, not the full digest.
        """
        return hashlib.sha256(api_key.encode()).hexdigest()[:16]

### Classification

Result of classifying a single statement.

In [None]:
#| export
@dataclass
class Classification:
    """Result of classifying a data or code availability statement.

    Attributes:
        category: The classified openness category
        statement_type: Whether this is a data or code classification
        confidence_score: Model confidence (0-1), higher is more confident
        reasoning: Optional chain-of-thought reasoning from LLM
        timestamp: When classification was made (naive datetime expressed in UTC)
        model_config: LLM configuration used for reproducibility
        few_shot_example_ids: IDs of training examples used in prompt

    Raises:
        ValueError: On construction, if ``confidence_score`` is outside [0, 1].
    """
    category: OpennessCategory
    statement_type: ClassificationType
    confidence_score: float = 0.8
    reasoning: Optional[str] = None
    # datetime.utcnow() is deprecated since Python 3.12. Taking the aware "now"
    # and dropping tzinfo yields the same naive UTC value, so the serialized
    # format stays unchanged.
    timestamp: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None)
    )
    model_config: Optional[LLMConfiguration] = None
    few_shot_example_ids: List[str] = field(default_factory=list)

    def __post_init__(self):
        """Validate confidence score is in range [0, 1]."""
        if not 0 <= self.confidence_score <= 1:
            raise ValueError(f"confidence_score must be between 0 and 1, got {self.confidence_score}")

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization.

        Enum fields are stored by value, the timestamp in ISO 8601 format, and the
        model configuration through its own ``to_dict`` (None when absent).
        """
        return {
            'category': self.category.value,
            'statement_type': self.statement_type.value,
            'confidence_score': self.confidence_score,
            'reasoning': self.reasoning,
            'timestamp': self.timestamp.isoformat(),
            'model_config': self.model_config.to_dict() if self.model_config else None,
            # Copy the list so edits to the returned dict cannot alter this instance
            'few_shot_example_ids': list(self.few_shot_example_ids),
        }

## LLM Provider Abstraction

Uses LiteLLM to provide a unified interface across Claude, OpenAI, and Ollama.

In [None]:
#| export
class LLMProvider:
    """Unified LLM provider interface using LiteLLM.

    Supports Claude, OpenAI, and Ollama with consistent interface.
    Includes retry logic with exponential backoff for transient errors.

    Example:
        >>> config = LLMConfiguration(
        ...     provider=LLMProviderType.CLAUDE,
        ...     model_name='claude-3-5-sonnet-20241022'
        ... )
        >>> provider = LLMProvider(config)
        >>> response = provider.complete("Classify this statement...")
    """

    def __init__(self, config: LLMConfiguration):
        """Store the configuration and derive the LiteLLM model identifier.

        Raises:
            ConfigurationError: If ``config.provider`` is not a supported provider.
        """
        self.config = config
        self._setup_provider()

    def _setup_provider(self):
        """Configure the LLM provider based on config.

        Sets ``self.model_id`` to the model string handed to LiteLLM. LiteLLM
        routes on that string: bare ``claude-*`` and OpenAI model names are used
        as-is, other models are addressed with a ``<provider>/`` prefix
        (``anthropic/`` for Anthropic, ``ollama/`` for Ollama).
        """
        provider = self.config.provider
        model_name = self.config.model_name

        if provider == LLMProviderType.CLAUDE:
            # "claude/" is not a LiteLLM provider prefix; Anthropic's is "anthropic/".
            # Names that already route correctly are left untouched.
            if model_name.startswith(('claude', 'anthropic/')):
                self.model_id = model_name
            else:
                self.model_id = f"anthropic/{model_name}"
        elif provider == LLMProviderType.OPENAI:
            self.model_id = model_name
        elif provider == LLMProviderType.OLLAMA:
            # Avoid "ollama/ollama/<model>" when the caller already added the prefix
            if model_name.startswith('ollama/'):
                self.model_id = model_name
            else:
                self.model_id = f"ollama/{model_name}"
            if self.config.api_endpoint:
                # NOTE: process-wide setting; it affects every Ollama call in this process
                os.environ['OLLAMA_API_BASE'] = self.config.api_endpoint
        else:
            raise ConfigurationError(f"Unknown provider: {self.config.provider}")

    def complete(
        self,
        prompt: str,
        max_retries: int = 3,
        retry_delay: float = 1.0
    ) -> str:
        """Generate completion for the given prompt.

        Errors whose message looks transient (rate limit, timeout, overload,
        HTTP 429/503/504/529) are retried with exponential backoff; any other
        error aborts immediately.

        Args:
            prompt: The prompt to complete
            max_retries: Maximum number of retry attempts
            retry_delay: Initial delay between retries (doubles each retry)

        Returns:
            The model's response text

        Raises:
            LLMError: If the request fails with a non-retryable error or all
                retries are exhausted. The original exception is chained as
                ``__cause__`` and ``retryable`` tells whether the last failure
                looked transient.
        """
        import litellm

        # Substrings (matched against the lower-cased error text) that mark a
        # failure as transient. The underscore/no-space spellings cover error
        # type names such as "RateLimitError" and "rate_limit_error".
        transient_markers = (
            'rate limit', 'rate_limit', 'ratelimit', 'timeout', 'overloaded',
            'service unavailable', '429', '529', '503', '504',
        )

        last_error = None
        last_error_transient = False
        attempts_made = 0
        delay = retry_delay

        for attempt in range(max_retries + 1):
            attempts_made = attempt + 1
            try:
                response = litellm.completion(
                    model=self.model_id,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.config.temperature,
                    max_tokens=self.config.max_tokens,
                    top_p=self.config.top_p,
                )
                return response.choices[0].message.content
            except Exception as e:
                # Broad on purpose: provider failures surface as many exception
                # types, and all of them are re-raised below as LLMError.
                last_error = e
                error_str = str(e).lower()
                last_error_transient = any(marker in error_str for marker in transient_markers)

                if not (last_error_transient and attempt < max_retries):
                    break

                logging.warning(
                    "LLM request failed (attempt %d/%d): %s. Retrying in %.1fs...",
                    attempt + 1, max_retries + 1, e, delay,
                )
                time.sleep(delay)
                delay *= 2  # Exponential backoff

        # Report the number of attempts actually made: a non-retryable error
        # stops after the first one, regardless of max_retries.
        raise LLMError(
            f"LLM request failed after {attempts_made} attempt(s): {last_error}",
            provider=self.config.provider.value,
            retryable=last_error_transient
        ) from last_error

## Classification Logging

Logs every classification decision in JSON Lines format so that runs can be audited and reproduced.

In [None]:
#| export
class ClassificationLogger:
    """Logger for classification decisions in JSON Lines format.

    Logs all classification decisions with full metadata for:
    - Reproducibility (FAIR principles)
    - Audit trail
    - Debugging and analysis

    Each call appends one JSON object on its own line to the log file.

    Example:
        >>> logger = ClassificationLogger('logs/classifications.jsonl')
        >>> logger.log_classification(
        ...     publication_id='doi:10.1234/example',
        ...     classification=classification_result,
        ...     statement_text='Data available at Zenodo...'
        ... )
    """

    def __init__(self, log_path: str | Path):
        """Remember the log file location and create its parent directory.

        Args:
            log_path: Path of the JSON Lines file to append to.
        """
        self.log_path = Path(log_path)
        self.log_path.parent.mkdir(parents=True, exist_ok=True)

    def log_classification(
        self,
        publication_id: str,
        classification: Classification,
        statement_text: str,
        extra: Optional[Dict[str, Any]] = None
    ):
        """Log a classification decision.

        Args:
            publication_id: Unique identifier for the publication
            classification: The classification result
            statement_text: The original statement text (only the first 500
                characters are stored)
            extra: Optional additional metadata
        """
        log_entry = {
            'timestamp': self._utc_timestamp(),
            'publication_id': publication_id,
            'statement_type': classification.statement_type.value,
            'statement_text': statement_text[:500],  # Truncate for log size
            'classification': classification.to_dict(),
        }

        if extra:
            log_entry['extra'] = extra

        self._append(log_entry)

    def log_error(
        self,
        publication_id: str,
        error: Exception,
        context: Optional[Dict[str, Any]] = None
    ):
        """Log a classification error.

        Args:
            publication_id: Unique identifier for the publication
            error: The exception to record (its type name and message are stored)
            context: Optional additional metadata
        """
        log_entry = {
            'timestamp': self._utc_timestamp(),
            'publication_id': publication_id,
            'error': True,
            'error_type': type(error).__name__,
            'error_message': str(error),
        }

        if context:
            log_entry['context'] = context

        self._append(log_entry)

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO 8601 string."""
        # datetime.utcnow() is deprecated since Python 3.12; dropping tzinfo keeps
        # the timestamp format identical to what it produced.
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def _append(self, log_entry: Dict[str, Any]) -> None:
        """Append one entry to the log file as a single JSON line."""
        # default=str is a deliberate choice for a logger: caller-supplied metadata
        # (paths, datetimes, exceptions) is stringified instead of raising
        # TypeError, which could otherwise mask the error being logged.
        line = json.dumps(log_entry, default=str)
        with open(self.log_path, 'a', encoding='utf-8') as f:
            f.write(line + '\n')

In [None]:
# Smoke test: build a Classification together with its LLM configuration
test_config = LLMConfiguration(
    model_name='claude-3-5-sonnet-20241022',
    provider=LLMProviderType.CLAUDE,
)

test_classification = Classification(
    statement_type=ClassificationType.DATA,
    category=OpennessCategory.OPEN,
    reasoning="Data is available on Zenodo with no restrictions.",
    confidence_score=0.92,
    model_config=test_config,
)

# One print call with a newline separator emits the same three lines
print(
    f"Classification: {test_classification.category.value}",
    f"Confidence: {test_classification.confidence_score}",
    "Classification tests passed!",
    sep="\n",
)

In [None]:
#| hide
# nbdev needs this cell in order to run the export
import nbdev
nbdev.nbdev_export()