# 04_newdata_inference — Multi-Model Inference with Auto-Detection

This notebook supports inference using multiple models:
- **CNN Model**: If `artifacts/cnn1d.pt` + `artifacts/calibrator.joblib` exist
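- **XGBoost Model**: If `artifacts/xgboost_model.pkl` (or `xgb_model.pkl` / `model_xgb.pkl`) exists; an optional `artifacts/xgb_calibrator.joblib` is applied when present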
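- **Other Models**: Extensible to support additional models. Scikit-learn files (`sklearn_model.pkl`, `rf_model.pkl`, `svm_model.pkl`) are already detected, but no loader is implemented for them yet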

Output format remains consistent regardless of the model used.

In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
import yaml
from datetime import datetime
warnings.filterwarnings('ignore')

# Try importing optional dependencies
try:
    import torch
    torch_available = True
except ImportError:
    torch = None
    torch_available = False
    print("⚠ PyTorch not available; CNN inference disabled.")

try:
    import xgboost as xgb
    xgb_available = True
except ImportError:
    xgb = None
    xgb_available = False
    print("⚠ XGBoost not available; XGBoost inference disabled.")

import joblib
import sys
sys.path.append('..')

# Check device availability for PyTorch
# Preference order: CUDA GPU, then Apple MPS, then CPU. Without PyTorch the
# device name is still defined (as 'cpu') so later cells can reference it.
if torch_available:
    # Older PyTorch releases do not define torch.backends.mps at all, so look
    # the backend up defensively instead of assuming the attribute exists.
    mps_backend = getattr(torch.backends, 'mps', None)
    if torch.cuda.is_available():
        device = 'cuda'
        print(f"✓ Using GPU: {torch.cuda.get_device_name(0)}")
    elif mps_backend is not None and mps_backend.is_available():
        device = 'mps'
        print("✓ Using Apple MPS")
    else:
        device = 'cpu'
        print("✓ Using CPU for PyTorch")
else:
    device = 'cpu'

print(f"\nAvailable backends:")
print(f"  PyTorch: {torch_available}")
print(f"  XGBoost: {xgb_available}")

## 1. Model Detection and Loading

In [None]:
class ModelLoader:
    """Unified model loader with automatic detection.

    Looks for known model files inside ``artifacts_dir``, loads the first
    usable one according to a preference order, and records it in
    ``self.models`` (keyed by model type) and ``self.active_model``.

    Relies on the module-level names ``torch``, ``torch_available``,
    ``xgb_available``, ``joblib`` and ``device``.
    """

    # Candidate file names checked, in order, for a serialized XGBoost model.
    XGB_MODEL_FILENAMES = ('xgboost_model.pkl', 'xgb_model.pkl', 'model_xgb.pkl')
    # Candidate file names checked, in order, for a scikit-learn model.
    SKLEARN_MODEL_FILENAMES = ('sklearn_model.pkl', 'rf_model.pkl', 'svm_model.pkl')

    def __init__(self, artifacts_dir='../artifacts'):
        """Remember the artifacts directory; nothing is loaded yet."""
        self.artifacts_dir = Path(artifacts_dir)
        self.models = {}
        self.active_model = None

    def _first_existing(self, filenames):
        """Return the first path under ``artifacts_dir`` that exists, else None."""
        for filename in filenames:
            candidate = self.artifacts_dir / filename
            if candidate.exists():
                return candidate
        return None

    def detect_available_models(self):
        """Detect which models are available.

        Returns:
            list[str]: Any of ``'CNN'``, ``'XGBoost'`` and ``'sklearn'``, in
            that order. CNN and XGBoost are only reported when their backend
            library was importable. ``'sklearn'`` is reported for discovery
            only: this class has no loader for it.
        """
        available = []

        # Check for CNN model (needs both the weights and the calibrator)
        cnn_path = self.artifacts_dir / 'cnn1d.pt'
        cal_path = self.artifacts_dir / 'calibrator.joblib'
        if cnn_path.exists() and cal_path.exists() and torch_available:
            available.append('CNN')
            print(f"✓ CNN model found: {cnn_path}")

        # Check for XGBoost model
        xgb_path = self._first_existing(self.XGB_MODEL_FILENAMES)
        if xgb_path is not None and xgb_available:
            available.append('XGBoost')
            print(f"✓ XGBoost model found: {xgb_path}")

        # Check for scikit-learn models
        sklearn_path = self._first_existing(self.SKLEARN_MODEL_FILENAMES)
        if sklearn_path is not None:
            available.append('sklearn')
            print(f"✓ Scikit-learn model found: {sklearn_path}")

        return available

    def load_cnn_model(self):
        """Load CNN model and calibrator.

        Returns:
            bool: True on success (entry stored in ``self.models['CNN']``),
            False if anything went wrong while importing or loading.

        Raises:
            RuntimeError: If PyTorch is not available.
        """
        if not torch_available:
            raise RuntimeError("PyTorch not available")

        # Deliberately best-effort: any failure is reported and turned into
        # False so the caller can fall back to another model.
        try:
            from app.models.cnn1d import make_model

            # Load model weights onto the selected device, in eval mode
            model = make_model()
            model_path = self.artifacts_dir / 'cnn1d.pt'
            model.load_state_dict(torch.load(model_path, map_location=device))
            model.to(device)
            model.eval()

            # Load calibrator
            # NOTE(review): joblib.load unpickles the file; only use trusted artifacts.
            cal_path = self.artifacts_dir / 'calibrator.joblib'
            calibrator = joblib.load(cal_path)

            self.models['CNN'] = {
                'model': model,
                'calibrator': calibrator,
                'type': 'CNN'
            }

            print("✓ CNN model loaded successfully")
            return True

        except Exception as e:
            print(f"✗ Failed to load CNN model: {e}")
            return False

    def load_xgboost_model(self):
        """Load XGBoost model.

        Uses the first existing file from ``XGB_MODEL_FILENAMES`` and, when
        present, the optional ``xgb_calibrator.joblib``.

        Returns:
            bool: True on success (entry stored in ``self.models['XGBoost']``),
            False if no model file exists or loading failed.

        Raises:
            RuntimeError: If XGBoost is not available.
        """
        if not xgb_available:
            raise RuntimeError("XGBoost not available")

        model_path = self._first_existing(self.XGB_MODEL_FILENAMES)
        if model_path is None:
            # Report a missing file as an explicit failure instead of
            # falling through and returning None.
            print(f"✗ Failed to load XGBoost model: no model file in {self.artifacts_dir}")
            return False

        # Deliberately best-effort, mirroring load_cnn_model.
        try:
            # NOTE(review): joblib.load unpickles the file; only use trusted artifacts.
            model = joblib.load(model_path)

            # Check for separate calibrator
            cal_path = self.artifacts_dir / 'xgb_calibrator.joblib'
            calibrator = joblib.load(cal_path) if cal_path.exists() else None
        except Exception as e:
            print(f"✗ Failed to load XGBoost model: {e}")
            return False

        self.models['XGBoost'] = {
            'model': model,
            'calibrator': calibrator,
            'type': 'XGBoost'
        }

        print("✓ XGBoost model loaded successfully")
        return True

    def load_best_model(self, preference=('CNN', 'XGBoost', 'sklearn')):
        """Load the best available model based on preference.

        Args:
            preference: Iterable of model type names, most preferred first.

        Returns:
            The name of the model that was loaded (also stored in
            ``self.active_model``), or None when nothing could be loaded.
        """
        available = self.detect_available_models()

        if not available:
            print("\n✗ No models found in artifacts directory")
            return None

        loaders = {
            'CNN': self.load_cnn_model,
            'XGBoost': self.load_xgboost_model,
        }

        # Load based on preference; stop at the first model that loads.
        for model_type in preference:
            if model_type not in available:
                continue
            load = loaders.get(model_type)
            if load is None:
                # e.g. 'sklearn': detected on disk but no loader exists.
                print(f"⚠ {model_type} model detected but no loader is implemented; skipping")
                continue
            if load():
                self.active_model = model_type
                break

        return self.active_model

# Build a loader for the default artifacts directory and let it pick a model.
loader = ModelLoader()
active_model = loader.load_best_model()

# A falsy result means nothing could be loaded; later cells then fall back
# to mock predictions.
if not active_model:
    print("\n⚠ No models available, generating mock predictions")
else:
    print(f"\n✓ Active model: {active_model}")

## 2. Data Loading and Preprocessing

In [None]:
def load_test_data(artifacts_dir='../artifacts'):
    """Load test data from various sources.

    Sources are tried in order:

    1. ``<artifacts_dir>/test_set.npz`` with arrays ``time``, ``flux`` and,
       optionally, ``labels``.
    2. ``<artifacts_dir>/denoised_lc.npz`` with arrays ``time`` and
       ``flux_denoised``.
    3. Synthetic data from ``generate_synthetic_test_data()``.

    Args:
        artifacts_dir: Directory searched for the ``.npz`` files.

    Returns:
        For sources 1 and 2, a tuple ``(time, flux, labels)`` where ``labels``
        is None when unavailable. For source 3, whatever
        ``generate_synthetic_test_data()`` returns.
        NOTE(review): the fallback's return shape differs from the tuple
        returned by the file-based sources; callers must handle both.
    """
    artifacts = Path(artifacts_dir)

    # Option 1: Load from processed test set
    test_path = artifacts / 'test_set.npz'
    if test_path.exists():
        print("Loading test data from artifacts...")
        # The context manager closes the archive; arrays are read out before.
        with np.load(test_path) as data:
            # Test membership via .files: NpzFile has no .get() on older NumPy.
            labels = data['labels'] if 'labels' in data.files else None
            return data['time'], data['flux'], labels

    # Option 2: Load from TLS search results
    tls_path = artifacts / 'denoised_lc.npz'
    if tls_path.exists():
        print("Loading data from TLS search results...")
        with np.load(tls_path) as data:
            return data['time'], data['flux_denoised'], None

    # Option 3: Generate synthetic data
    print("Generating synthetic test data...")
    return generate_synthetic_test_data()

def generate_synthetic_test_data(n_samples=20, n_points=2048, seed=None):
    """Generate synthetic light curves for testing.

    Even-indexed samples contain a periodic box-shaped dip (label 1);
    odd-indexed samples are noise only (label 0). Every sample gets Gaussian
    noise with standard deviation 5e-4 around a baseline flux of 1.

    Args:
        n_samples: Number of light curves to generate.
        n_points: Number of time stamps per light curve, evenly spaced over
            the interval [0, 27.4].
        seed: Optional seed for reproducible output. When None, the global
            ``np.random`` state is used.

    Returns:
        list[dict]: One dict per sample with keys ``time``, ``flux``,
        ``period``, ``t0``, ``duration``, ``label`` and ``id``.
    """
    # A private RandomState keeps seeded runs from touching global state.
    rng = np.random if seed is None else np.random.RandomState(seed)
    samples = []

    for i in range(n_samples):
        has_transit = i % 2 == 0

        # Time array and flat baseline flux
        t = np.linspace(0, 27.4, n_points)
        flux = np.ones_like(t)

        # Transit parameters are drawn for every sample (including the
        # non-transit ones) because they are stored in the sample dict.
        period = rng.uniform(2, 10)
        t0 = rng.uniform(0, 2)
        duration = rng.uniform(0.05, 0.15)

        if has_transit:
            depth = rng.uniform(0.0005, 0.002)
            # The dip covers the phase interval [0, duration / period), so it
            # starts at t0 in each cycle.
            # NOTE(review): if consumers treat t0 as the mid-transit time,
            # the dip is offset by duration / 2 — confirm.
            phase = ((t - t0) / period) % 1.0
            in_transit = phase < (duration / period)
            flux[in_transit] -= depth

        # Add noise
        flux += rng.normal(0, 5e-4, size=flux.shape)

        samples.append({
            'time': t,
            'flux': flux,
            'period': period,
            't0': t0,
            'duration': duration,
            'label': 1 if has_transit else 0,
            'id': f'TIC_{100000 + i}'
        })

    return samples

# Build the evaluation set (synthetic light curves) and report class balance.
test_samples = generate_synthetic_test_data()
n_transit_samples = sum(sample['label'] for sample in test_samples)
n_non_transit_samples = sum(1 - sample['label'] for sample in test_samples)
print(f"\nLoaded {len(test_samples)} test samples")
print(f"Transit samples: {n_transit_samples}")
print(f"Non-transit samples: {n_non_transit_samples}")

## 3. Inference Functions

In [None]:
def run_cnn_inference(model_dict, samples):
    """Run inference using CNN model.

    Args:
        model_dict: Dict with keys ``'model'`` (the CNN) and ``'calibrator'``
            (an object with ``transform``, or None).
        samples: Iterable of dicts with keys ``time``, ``flux``, ``period``,
            ``t0`` and ``duration``.

    Returns:
        np.ndarray: One (optionally calibrated) probability per sample.
    """
    from app.data.fold import make_views

    model = model_dict['model']
    calibrator = model_dict['calibrator']
    predictions = []

    # Inference only: no gradients are needed.
    with torch.no_grad():
        for sample in samples:
            # Create global and local views
            g_view, l_view = make_views(
                sample['time'],
                sample['flux'],
                sample['period'],
                sample['t0'],
                sample['duration']
            )

            # Convert to tensors with a leading batch dimension of 1
            G = torch.tensor(g_view, dtype=torch.float32).unsqueeze(0).to(device)
            L = torch.tensor(l_view, dtype=torch.float32).unsqueeze(0).to(device)

            # Get prediction
            logits = model(G, L).squeeze()
            prob = torch.sigmoid(logits).cpu().numpy()

            # Apply calibration
            if calibrator is not None:
                # Flatten before indexing: some calibrators return shape
                # (n_samples,) rather than (n_samples, 1), and [0][0] would
                # fail on the 1-D result.
                prob = np.asarray(calibrator.transform([[prob]])).ravel()[0]

            predictions.append(float(prob))

    return np.array(predictions)

def run_xgboost_inference(model_dict, samples):
    """Run inference using XGBoost model.

    Args:
        model_dict: Dict with key ``'model'`` and optional key
            ``'calibrator'`` (an object with ``transform``).
        samples: Iterable of sample dicts accepted by
            ``extract_features_for_xgb``.

    Returns:
        np.ndarray: One (optionally calibrated) probability per sample.
    """
    model = model_dict['model']
    calibrator = model_dict.get('calibrator')
    # Decide once which API the model exposes.
    has_predict_proba = hasattr(model, 'predict_proba')
    predictions = []

    for sample in samples:
        # Extract features as an explicit (1, n_features) float array so both
        # the scikit-learn wrapper and DMatrix receive a real 2-D matrix.
        features = np.asarray([extract_features_for_xgb(sample)], dtype=float)

        # Get prediction
        if has_predict_proba:
            # Column 1 is the probability of the positive class.
            prob = model.predict_proba(features)[0][1]
        else:
            # For XGBoost native API
            import xgboost as xgb
            dtest = xgb.DMatrix(features)
            prob = model.predict(dtest)[0]

        # Apply calibration if available
        if calibrator is not None:
            # Flatten before indexing: some calibrators return shape
            # (n_samples,) rather than (n_samples, 1).
            prob = np.asarray(calibrator.transform([[prob]])).ravel()[0]

        predictions.append(float(prob))

    return np.array(predictions)

def extract_features_for_xgb(sample):
    """Build the XGBoost feature vector for one sample.

    Args:
        sample: Dict with keys ``flux``, ``period`` and ``duration``.

    Returns:
        list: Nine values in fixed order: mean, standard deviation, minimum,
        maximum, median, 25th percentile and 75th percentile of the flux,
        followed by the sample's period and duration.
    """
    flux_values = sample['flux']

    # Both quartiles in a single percentile call.
    lower_quartile, upper_quartile = np.percentile(flux_values, [25, 75])

    flux_statistics = [
        np.mean(flux_values),
        np.std(flux_values),
        np.min(flux_values),
        np.max(flux_values),
        np.median(flux_values),
        lower_quartile,
        upper_quartile,
    ]
    transit_parameters = [sample['period'], sample['duration']]

    # Further features (Fourier components, autocorrelation, ...) could be
    # appended here if needed.
    return flux_statistics + transit_parameters

def run_inference(loader, samples):
    """Dispatch inference to whichever model the loader has activated.

    Args:
        loader: Object exposing ``active_model`` and a ``models`` dict.
        samples: Sequence of samples to score.

    Returns:
        np.ndarray: One probability per sample. When the active model is
        neither ``'CNN'`` nor ``'XGBoost'``, uniformly random mock values.
    """
    chosen = loader.active_model

    # Fallback first: unknown or missing model yields mock predictions.
    if chosen != 'CNN' and chosen != 'XGBoost':
        print("Running mock inference...")
        return np.random.random(len(samples))

    if chosen == 'CNN':
        print("Running CNN inference...")
        return run_cnn_inference(loader.models['CNN'], samples)

    print("Running XGBoost inference...")
    return run_xgboost_inference(loader.models['XGBoost'], samples)

## 4. Run Inference on Test Data

In [None]:
# Score every test sample; without a loaded model, fall back to random values.
if not loader.active_model:
    print("⚠ Using random predictions (no model available)")
    predictions = np.random.random(len(test_samples))
else:
    predictions = run_inference(loader, test_samples)

# Ground truth and identifiers, in the same order as the predictions
true_labels = np.array([sample['label'] for sample in test_samples])
sample_ids = [sample['id'] for sample in test_samples]

# One row per sample; a probability above 0.5 counts as a positive prediction
results_df = pd.DataFrame({
    'id': sample_ids,
    'probability': predictions,
    'prediction': (predictions > 0.5).astype(int),
    'true_label': true_labels,
})

# Flag rows where the thresholded prediction matches the label
results_df['correct'] = results_df['prediction'].eq(results_df['true_label'])

heavy_rule = "=" * 60
print("\n" + heavy_rule)
print("Inference Results:")
print(heavy_rule)
print(results_df.to_string(index=False))

# Headline numbers
accuracy = results_df['correct'].mean()
n_detected = results_df['prediction'].sum()
n_actual = results_df['true_label'].sum()
print("\n" + "-" * 60)
print(f"Accuracy: {accuracy:.2%}")
print(f"Predictions: {n_detected} planets detected")
print(f"Ground truth: {n_actual} actual planets")

## 5. Performance Metrics

In [None]:
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix
)

# Calculate comprehensive metrics from the thresholded predictions.
# zero_division=0 makes precision/recall/F1 report 0 when undefined.
metrics = {
    'model_type': loader.active_model or 'Mock',
    'n_samples': len(test_samples),
    'accuracy': accuracy_score(true_labels, results_df['prediction']),
    'precision': precision_score(true_labels, results_df['prediction'], zero_division=0),
    'recall': recall_score(true_labels, results_df['prediction'], zero_division=0),
    'f1_score': f1_score(true_labels, results_df['prediction'], zero_division=0),
}

# Add ROC-AUC and PR-AUC if we have both classes (they use the raw
# probabilities, not the thresholded predictions)
if len(np.unique(true_labels)) > 1:
    metrics['roc_auc'] = roc_auc_score(true_labels, predictions)
    metrics['pr_auc'] = average_precision_score(true_labels, predictions)

# Confusion matrix. Pinning labels=[0, 1] guarantees a 2x2 matrix even when
# only one class occurs, so the cm[0,1] / cm[1,*] lookups below cannot raise
# IndexError.
cm = confusion_matrix(true_labels, results_df['prediction'], labels=[0, 1])

print("\nPerformance Metrics:")
print("=" * 40)
for metric, value in metrics.items():
    if isinstance(value, float):
        print(f"{metric:15s}: {value:.4f}")
    else:
        print(f"{metric:15s}: {value}")

# Rows are actual classes, columns are predicted classes.
print("\nConfusion Matrix:")
print("                 Predicted")
print("                 No    Yes")
print(f"Actual No    {cm[0,0]:5d} {cm[0,1]:5d}")
print(f"       Yes   {cm[1,0]:5d} {cm[1,1]:5d}")

## 6. Visualization

In [None]:
# Create visualization: a 2x2 grid of diagnostic panels
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
ax_hist, ax_roc = axes[0, 0], axes[0, 1]
ax_pr, ax_scatter = axes[1, 0], axes[1, 1]

# 1. Probability distribution by class
transit_probs = predictions[true_labels == 1]
no_transit_probs = predictions[true_labels == 0]

hist_style = dict(bins=20, alpha=0.5, edgecolor='black')
ax_hist.hist(no_transit_probs, label='No Transit', color='blue', **hist_style)
ax_hist.hist(transit_probs, label='Transit', color='red', **hist_style)
ax_hist.axvline(x=0.5, color='black', linestyle='--', label='Threshold')
ax_hist.set_xlabel('Predicted Probability')
ax_hist.set_ylabel('Count')
ax_hist.set_title('Distribution of Predicted Probabilities')
ax_hist.legend()
ax_hist.grid(True, alpha=0.3)

# Both curve panels need two classes in the ground truth; otherwise they
# are left empty.
if len(np.unique(true_labels)) > 1:
    # 2. ROC Curve
    from sklearn.metrics import roc_curve, auc
    fpr, tpr, _ = roc_curve(true_labels, predictions)
    roc_auc = auc(fpr, tpr)

    ax_roc.plot(fpr, tpr, 'b-', linewidth=2, label=f'ROC (AUC = {roc_auc:.3f})')
    # Diagonal reference line from (0, 0) to (1, 1)
    ax_roc.plot([0, 1], [0, 1], 'k--', alpha=0.3)
    ax_roc.set_xlabel('False Positive Rate')
    ax_roc.set_ylabel('True Positive Rate')
    ax_roc.set_title('ROC Curve')
    ax_roc.legend(loc='lower right')
    ax_roc.grid(True, alpha=0.3)

    # 3. Precision-Recall Curve
    from sklearn.metrics import precision_recall_curve, average_precision_score
    precision, recall, _ = precision_recall_curve(true_labels, predictions)
    pr_auc = average_precision_score(true_labels, predictions)

    ax_pr.plot(recall, precision, 'r-', linewidth=2, label=f'PR (AP = {pr_auc:.3f})')
    ax_pr.set_xlabel('Recall')
    ax_pr.set_ylabel('Precision')
    ax_pr.set_title('Precision-Recall Curve')
    ax_pr.legend(loc='lower left')
    ax_pr.grid(True, alpha=0.3)

# 4. Scatter plot of predictions, coloured by true label
colors = ['red' if t != 0 else 'blue' for t in true_labels]
ax_scatter.scatter(range(len(predictions)), predictions, c=colors, alpha=0.6, s=50)
ax_scatter.axhline(y=0.5, color='black', linestyle='--', alpha=0.7)
ax_scatter.set_xlabel('Sample Index')
ax_scatter.set_ylabel('Predicted Probability')
ax_scatter.set_title('Predictions by Sample')
ax_scatter.grid(True, alpha=0.3)

# Add legend (the scatter colours come from a list, so a manual legend is built)
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor='red', label='True Transit'),
    Patch(facecolor='blue', label='True Non-Transit'),
]
ax_scatter.legend(handles=legend_elements)

plt.suptitle(f'Model Performance: {loader.active_model or "Mock"}', fontsize=14, fontweight='bold')
plt.tight_layout()

# Save figure (before plt.show(), so the saved image contains the plot)
reports_dir = Path('../reports')
reports_dir.mkdir(exist_ok=True)
figure_path = reports_dir / 'inference_visualization.png'
plt.savefig(figure_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\n✓ Visualization saved to {figure_path}")

## 7. Export Results

In [None]:
# Prepare output directory
outputs_dir = Path('../outputs')
outputs_dir.mkdir(exist_ok=True)

# One timestamp shared by every file written in this run
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Candidates CSV: positive predictions only, probability exposed as confidence
is_candidate = results_df['prediction'] == 1
candidates_df = (
    results_df.loc[is_candidate, ['id', 'probability']]
    .rename(columns={'probability': 'confidence'})
    .assign(model=loader.active_model or 'Mock')
)

candidates_file = outputs_dir / f'candidates_{timestamp}.csv'
candidates_df.to_csv(candidates_file, index=False)
print(f"\n✓ Saved {len(candidates_df)} candidates to {candidates_file}")

# Full results table
all_predictions_file = outputs_dir / f'all_predictions_{timestamp}.csv'
results_df.to_csv(all_predictions_file, index=False)
print(f"✓ Saved all predictions to {all_predictions_file}")

# Provenance YAML: convert float-like metric values to built-in float first,
# leave every other value untouched
yaml_safe_metrics = {}
for key, value in metrics.items():
    if isinstance(value, (np.floating, float)):
        value = float(value)
    yaml_safe_metrics[key] = value

provenance = {
    'timestamp': timestamp,
    'model': {
        'type': loader.active_model or 'Mock',
        'device': device if loader.active_model == 'CNN' else 'cpu',
        'artifacts_dir': str(loader.artifacts_dir),
    },
    'data': {
        'n_samples': len(test_samples),
        'n_transits': int(true_labels.sum()),
        'n_non_transits': int((1-true_labels).sum()),
    },
    'results': {
        'n_candidates': len(candidates_df),
        'metrics': yaml_safe_metrics,
    },
}

provenance_file = outputs_dir / f'provenance_{timestamp}.yaml'
with open(provenance_file, 'w') as f:
    yaml.dump(provenance, f, default_flow_style=False)
print(f"✓ Saved provenance to {provenance_file}")

# Metrics JSON for dashboard: named by model rather than timestamp, so a
# rerun with the same model overwrites the file
metrics_file = reports_dir / f'inference_metrics_{loader.active_model or "mock"}.json'
with open(metrics_file, 'w') as f:
    json.dump(metrics, f, indent=2)
print(f"✓ Saved metrics to {metrics_file}")

## 8. Summary Report

In [None]:
# Final human-readable recap of the run.
report_rule = "=" * 60

# Boolean masks for the four confusion-matrix cells
predicted_positive = results_df['prediction'] == 1
predicted_negative = results_df['prediction'] == 0
actually_positive = results_df['true_label'] == 1
actually_negative = results_df['true_label'] == 0

print("\n" + report_rule)
print("INFERENCE SUMMARY REPORT")
print(report_rule)
print("\nModel Information:")
print(f"  Type: {loader.active_model or 'Mock'}")
print(f"  Device: {device if loader.active_model == 'CNN' else 'CPU'}")
print(f"  Artifacts: {loader.artifacts_dir}")

print("\nDataset:")
print(f"  Total samples: {len(test_samples)}")
print(f"  Transit samples: {true_labels.sum()}")
print(f"  Non-transit samples: {(1-true_labels).sum()}")

print("\nPrediction Results:")
print(f"  Candidates detected: {len(candidates_df)}")
print(f"  True positives: {(predicted_positive & actually_positive).sum()}")
print(f"  False positives: {(predicted_positive & actually_negative).sum()}")
print(f"  True negatives: {(predicted_negative & actually_negative).sum()}")
print(f"  False negatives: {(predicted_negative & actually_positive).sum()}")

print("\nPerformance:")
print(f"  Accuracy: {metrics['accuracy']:.2%}")
# The AUC entries exist only when both classes were present in the labels
if 'roc_auc' in metrics:
    print(f"  ROC-AUC: {metrics['roc_auc']:.3f}")
if 'pr_auc' in metrics:
    print(f"  PR-AUC: {metrics['pr_auc']:.3f}")
print(f"  Precision: {metrics['precision']:.3f}")
print(f"  Recall: {metrics['recall']:.3f}")
print(f"  F1 Score: {metrics['f1_score']:.3f}")

print("\nOutput Files:")
print(f"  Candidates: {candidates_file}")
print(f"  All predictions: {all_predictions_file}")
print(f"  Provenance: {provenance_file}")
print(f"  Metrics: {metrics_file}")
print(f"  Visualization: {reports_dir / 'inference_visualization.png'}")

print("\n" + report_rule)
print("✓ Inference pipeline completed successfully!")
print(report_rule)

## Summary

This notebook provides a robust multi-model inference pipeline that:

1. **Auto-detects available models**: CNN, XGBoost, or other models
2. **Loads the best available model**: Based on user preference
3. **Runs appropriate inference**: Different pipelines for different model types
4. **Provides consistent output**: Same format regardless of model used
5. **Comprehensive evaluation**: Metrics, visualizations, and exports

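The system gracefully handles missing models and dependencies: if the preferred model cannot be loaded it tries the next one, and if no model is available it falls back to random mock predictions so the rest of the pipeline can still run. Metrics produced in that fallback mode are labelled `Mock` and do not reflect real model performance.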