# Batch Processing Module

> Batch classification of multiple publications from CSV files.

This module provides:
- `classify_csv()`: Process a CSV file with multiple publications
- `BatchJob`: Track batch processing progress and statistics
- Error handling with skip/fail/log options

In [None]:
#| default_exp batch

In [None]:
#| export
from __future__ import annotations
import pandas as pd
import uuid
from pathlib import Path
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional, List, Tuple, Callable, Literal
import logging
import json

from openness_classifier.core import (
    OpennessCategory,
    ClassificationType,
    Classification,
    BatchStatus,
    ClassificationError,
    DataError,
)
from openness_classifier.config import ClassifierConfig, load_config
from openness_classifier.data import Publication
from openness_classifier.classifier import OpennessClassifier, get_classifier

## BatchJob

Tracks batch processing progress and statistics.

In [None]:
#| export
@dataclass
class BatchJob:
    """Tracks batch processing of multiple publications.

    Attributes:
        job_id: Unique identifier for this job (first 8 characters of a UUID4)
        input_file: Path to input CSV
        output_file: Path to output CSV
        total_publications: Number of publications to process
        processed_count: Number processed so far (includes failed and skipped)
        failed_count: Number that failed
        skipped_count: Number that were skipped
        status: Current job status
        start_time: When job started
        end_time: When job finished (if complete)
        error_log: List of (publication_id, error_message) tuples
    """
    job_id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    input_file: Optional[Path] = None
    output_file: Optional[Path] = None
    total_publications: int = 0
    processed_count: int = 0
    failed_count: int = 0
    skipped_count: int = 0
    status: BatchStatus = BatchStatus.PENDING
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    error_log: List[Tuple[str, str]] = field(default_factory=list)

    @property
    def success_count(self) -> int:
        """Number of successful classifications.

        Computed as processed minus failed minus skipped.
        """
        return self.processed_count - self.failed_count - self.skipped_count

    @property
    def progress_percent(self) -> float:
        """Progress as percentage of ``total_publications`` (0.0 when the total is 0)."""
        if self.total_publications == 0:
            return 0.0
        return 100.0 * self.processed_count / self.total_publications

    @property
    def duration_seconds(self) -> Optional[float]:
        """Job duration in seconds.

        Returns ``None`` if the job has not started. While the job has no
        ``end_time`` the elapsed time up to now is returned, so successive
        reads can differ.
        """
        if not self.start_time:
            return None
        end = self.end_time or datetime.now()
        return (end - self.start_time).total_seconds()

    def to_dict(self) -> dict:
        """Convert to a JSON-serializable dictionary.

        Paths are rendered as strings, timestamps as ISO-8601 strings and the
        status as its enum value. The individual error entries are included
        under ``'error_log'`` (as ``[publication_id, message]`` pairs) so that
        a persisted job record keeps the failure details, not just their count.
        """
        return {
            'job_id': self.job_id,
            'input_file': str(self.input_file) if self.input_file else None,
            'output_file': str(self.output_file) if self.output_file else None,
            'total_publications': self.total_publications,
            'processed_count': self.processed_count,
            'success_count': self.success_count,
            'failed_count': self.failed_count,
            'skipped_count': self.skipped_count,
            'status': self.status.value,
            'start_time': self.start_time.isoformat() if self.start_time else None,
            'end_time': self.end_time.isoformat() if self.end_time else None,
            'duration_seconds': self.duration_seconds,
            'error_count': len(self.error_log),
            'error_log': [list(entry) for entry in self.error_log],
        }

    def summary(self) -> str:
        """Human-readable, multi-line summary of job status."""
        lines = [
            f"BatchJob {self.job_id}",
            f"Status: {self.status.value}",
            f"Progress: {self.processed_count}/{self.total_publications} ({self.progress_percent:.1f}%)",
            f"Success: {self.success_count}, Failed: {self.failed_count}, Skipped: {self.skipped_count}",
        ]
        # Read the property once: for a running job every read calls
        # datetime.now(). Compare against None so a 0.0s duration still shows.
        duration = self.duration_seconds
        if duration is not None:
            lines.append(f"Duration: {duration:.1f}s")
        return "\n".join(lines)

## Batch Classification

In [None]:
#| export
def classify_csv(
    input_path: str | Path,
    output_path: Optional[str | Path] = None,
    config: Optional[ClassifierConfig] = None,
    id_column: str = 'doi',
    data_statement_column: str = 'data_statement',
    code_statement_column: str = 'code_statement',
    error_handling: Literal['skip', 'fail', 'log'] = 'log',
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> BatchJob:
    """Classify publications from a CSV file.

    Reads publications from input CSV, classifies data and code availability,
    and writes results to output CSV. Four columns are added to the output:
    ``data_classification``, ``data_confidence``, ``code_classification`` and
    ``code_confidence``. A JSON record of the job is written afterwards.

    Args:
        input_path: Path to input CSV file
        output_path: Path for output CSV. Defaults to the input path with its
            final suffix replaced by ``.classified.csv``
            (e.g. ``pubs.csv`` -> ``pubs.classified.csv``).
        config: Classifier configuration
        id_column: Column name for publication ID (the row index is used when
            the column is absent)
        data_statement_column: Column name for data statements
        code_statement_column: Column name for code statements
        error_handling: How to handle errors:
            - 'skip': Skip failed rows, continue processing
            - 'fail': Stop on first error
            - 'log': Log error and continue (default)
            In every mode the failure is counted in ``failed_count`` and
            recorded in ``error_log``; 'skip' and 'log' differ only in that
            'log' also emits a logging error.
        progress_callback: Optional callback(processed, total) for progress

    Returns:
        BatchJob with statistics and any errors

    Raises:
        DataError: If the input CSV cannot be read.
        Exception: With ``error_handling='fail'``, the first per-row error is
            re-raised unchanged (no output CSV is written in that case).
    """
    input_path = Path(input_path)
    if output_path is None:
        output_path = input_path.with_suffix('.classified.csv')
    else:
        output_path = Path(output_path)

    # Initialize job
    job = BatchJob(
        input_file=input_path,
        output_file=output_path,
        status=BatchStatus.PENDING,
    )

    # Load input CSV
    try:
        df = pd.read_csv(input_path)
    except Exception as e:
        job.status = BatchStatus.FAILED
        job.error_log.append(('_load', str(e)))
        # Chain the cause so the original traceback is preserved.
        raise DataError(f"Failed to read input CSV: {e}") from e

    job.total_publications = len(df)
    job.status = BatchStatus.RUNNING
    job.start_time = datetime.now()

    # Get classifier
    classifier = get_classifier(config)

    # Add output columns
    df['data_classification'] = None
    df['data_confidence'] = None
    df['code_classification'] = None
    df['code_confidence'] = None

    # Process each row
    for idx, row in df.iterrows():
        pub_id = str(row.get(id_column, idx))

        try:
            # Create publication
            pub = Publication(
                id=pub_id,
                data_statement=_get_statement(row, data_statement_column),
                code_statement=_get_statement(row, code_statement_column),
            )

            # Classify
            data_cls, code_cls = classifier.classify_publication(pub)

            # Store results
            if data_cls:
                df.at[idx, 'data_classification'] = data_cls.category.value
                df.at[idx, 'data_confidence'] = data_cls.confidence_score

            if code_cls:
                df.at[idx, 'code_classification'] = code_cls.category.value
                df.at[idx, 'code_confidence'] = code_cls.confidence_score

            # Track skipped (no statements)
            if not data_cls and not code_cls:
                job.skipped_count += 1

        except Exception as e:
            job.failed_count += 1
            job.error_log.append((pub_id, str(e)))

            if error_handling == 'fail':
                job.status = BatchStatus.FAILED
                job.end_time = datetime.now()
                raise
            elif error_handling == 'log':
                logging.error(f"Failed to classify {pub_id}: {e}")
            # 'skip' just continues

        # Failed rows count as processed too; success_count subtracts them.
        job.processed_count += 1

        if progress_callback:
            progress_callback(job.processed_count, job.total_publications)

    # Save output; create the target directory first so a not-yet-existing
    # output folder does not discard the whole run.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_path, index=False)

    job.status = BatchStatus.COMPLETED
    job.end_time = datetime.now()

    # Log batch job details
    _log_batch_job(job, config)

    return job


def _get_statement(row: pd.Series, column: str) -> Optional[str]:
    """Extract statement from row, handling missing values."""
    if column not in row.index:
        return None
    value = row[column]
    if pd.isna(value) or str(value).strip().lower() in ('', 'nothing', 'nan'):
        return None
    return str(value).strip()


def _log_batch_job(job: BatchJob, config: Optional[ClassifierConfig]) -> None:
    """Write the job's ``to_dict()`` snapshot to ``batch_<job_id>.json``.

    The file goes into ``config.log_dir`` when a config is supplied, otherwise
    into a ``logs`` directory relative to the current working directory.
    Missing directories are created.
    """
    log_dir = config.log_dir if config else Path('logs')
    report_file = log_dir / f"batch_{job.job_id}.json"

    report_file.parent.mkdir(parents=True, exist_ok=True)

    with open(report_file, 'w') as handle:
        json.dump(job.to_dict(), handle, indent=2)

In [None]:
# Smoke-test BatchJob's derived statistics on a half-finished job
job = BatchJob(
    status=BatchStatus.RUNNING,
    total_publications=100,
    processed_count=50,
    failed_count=2,
    skipped_count=5,
)

# One print call, newline-separated, emits the same three sections of output.
print(
    f"Success count: {job.success_count}",
    f"Progress: {job.progress_percent}%",
    job.summary(),
    sep="\n",
)

In [None]:
#| hide
# Export the cells marked `#| export` in this notebook to the package module.
import nbdev

nbdev.nbdev_export()