# Exoplanet Detection using Machine Learning on Transit Data

This notebook demonstrates how to use machine learning techniques to detect exoplanets from transit light curves using the ExoplanetHunter-TransitML package.

## Overview

We'll walk through the complete pipeline:
1. Loading transit data using lightkurve
2. Preprocessing the light curves
3. Creating features for machine learning
4. Training ML models for transit detection
5. Validating results using time series cross-validation
6. Evaluating the best model on a held-out temporal test split

## Requirements

Make sure you have installed all required packages:
```bash
pip install -r requirements.txt
```

In [1]:
# Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import sys

# Add src directory to path
sys.path.append('../src')

# Import our custom modules
from exoplanet_hunter import (
    TransitDataLoader,
    TransitPreprocessor,
    TransitClassifier,
    GroupedTimeSeriesValidator
)

# Set up plotting
plt.style.use('seaborn-v0_8')
plt.rcParams['figure.figsize'] = (12, 8)

print("All imports successful!")

All imports successful!


## 1. Data Loading

Let's start by loading some transit data using lightkurve. We'll use known exoplanet host stars for our demonstration.
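If MAST is slow or unavailable, it can help to develop the later cells against synthetic data with a known answer. The sketch below is a hypothetical helper (`make_synthetic_light_curve` is not part of ExoplanetHunter-TransitML). It injects box-shaped transits into flat Gaussian noise and returns a DataFrame with the `time`, `flux`, and `target` columns used later in this notebook, plus a ground-truth column that real data does not have. Real transits have rounded, limb-darkened profiles, so treat this only as a stand-in.

```python
import numpy as np
import pandas as pd


def make_synthetic_light_curve(target, period=3.0, duration=0.1, depth=0.01,
                               n_points=5000, baseline_days=27.0,
                               noise=0.001, seed=0):
    """Hypothetical helper: box-shaped transits on a flat, noisy baseline.

    Times are in days and flux is relative (median ~1). The first transit
    starts at t = 0 and repeats every `period` days.
    """
    rng = np.random.default_rng(seed)
    time = np.linspace(0.0, baseline_days, n_points)
    in_transit = (time % period) < duration
    flux = 1.0 + rng.normal(0.0, noise, n_points)
    flux[in_transit] -= depth
    return pd.DataFrame({
        "time": time,
        "flux": flux,
        "target": target,
        # Ground truth exists only for synthetic data
        "in_transit_truth": in_transit.astype(int),
    })


# Three fake targets with periods of 2, 3 and 4 days
synthetic = pd.concat(
    [make_synthetic_light_curve(f"SIM-{i}", period=2.0 + i, seed=i) for i in range(3)],
    ignore_index=True,
)
print(synthetic.groupby("target")["in_transit_truth"].mean())
```

The in-transit fraction printed per target should be close to `duration / period`, which is a quick sanity check on the generator.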

In [None]:
# Initialize data loader
data_loader = TransitDataLoader(cache_dir='../data/cache')

# Get list of known exoplanet targets
targets = data_loader.get_known_exoplanet_targets(max_targets=3)
print(f"Selected targets: {targets}")

# Load light curves for these targets
light_curves = data_loader.load_multiple_targets(targets, mission="TESS")

print(f"\nSuccessfully loaded {len(light_curves)} light curves")
for target, lc in light_curves.items():
    print(f"  {target}: {len(lc)} data points")

In [None]:
# Plot raw light curves
fig, axes = plt.subplots(len(light_curves), 1, figsize=(15, 4*len(light_curves)))
if len(light_curves) == 1:
    axes = [axes]

for i, (target, lc) in enumerate(light_curves.items()):
    lc.plot(ax=axes[i])
    axes[i].set_title(f"Raw Light Curve: {target}")
    axes[i].set_ylabel("Flux")

plt.tight_layout()
plt.show()

## 2. Data Preprocessing

Now let's preprocess the light curves to prepare them for machine learning analysis.
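The preprocessor's internals are not shown in this notebook. As a rough guide to what its three settings usually mean, the NumPy sketch below uses a running median as a simpler stand-in for the biweight filter (the `wotan` package provides a real biweight detrending filter), a one-sided sigma clip based on the median absolute deviation, and normalization to a median of 1. All function names are hypothetical and the 0.5-day window is an assumption. Whether `TransitPreprocessor` clips only upward outliers is not known; lightkurve's `LightCurve.remove_outliers` exposes separate `sigma_lower` and `sigma_upper` arguments because a symmetric clip can delete real transit points.

```python
import numpy as np


def running_median_detrend(time, flux, window_days=0.5):
    """Divide flux by a running median (a simple stand-in for a biweight filter).

    `time` must be sorted. Each window spans [t - window/2, t + window/2].
    """
    half = window_days / 2.0
    lo = np.searchsorted(time, time - half, side="left")
    hi = np.searchsorted(time, time + half, side="right")
    trend = np.array([np.median(flux[a:b]) for a, b in zip(lo, hi)])
    return flux / trend


def robust_std(x):
    """Gaussian-equivalent standard deviation from the median absolute deviation."""
    return 1.4826 * np.median(np.abs(x - np.median(x)))


def clip_upper_outliers(time, flux, sigma=5.0):
    """Drop points more than `sigma` robust std devs ABOVE the median.

    Only the upper side is clipped so that transit dips survive.
    """
    keep = flux - np.median(flux) < sigma * robust_std(flux)
    return time[keep], flux[keep]


def normalize_to_median(flux):
    """Scale flux so its median is 1.0; dips then read as fractional depths."""
    return flux / np.median(flux)


# Synthetic test: slow drift, 1% transits every 3 days, three upward spikes
rng = np.random.default_rng(1)
time = np.linspace(0.0, 27.0, 5000)
trend = 1.0 + 0.02 * np.sin(2 * np.pi * time / 10.0)
in_transit = (time % 3.0) < 0.1
flux = trend * np.where(in_transit, 0.99, 1.0) + rng.normal(0.0, 0.001, time.size)
flux[[100, 2000, 4000]] += 0.05

flat = running_median_detrend(time, flux)
t_clean, f_clean = clip_upper_outliers(time, flat, sigma=5.0)
f_norm = normalize_to_median(f_clean)
print(len(time) - len(t_clean), "points clipped")
```

The detrending window has to be several times longer than the transit duration; a window close to the transit length would partly fit the dip and erase it.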

In [None]:
# Initialize preprocessor
preprocessor = TransitPreprocessor(
    detrend_method="biweight",
    outlier_sigma=5.0,
    normalization_method="robust"
)

# Process each light curve
processed_data = {}

for target, lc in light_curves.items():
    print(f"\nProcessing {target}...")
    
    # Extract features from light curve
    df = data_loader.extract_features_from_lightcurve(lc)
    
    # Apply preprocessing pipeline
    processed_df = preprocessor.process_light_curve(df, create_features=True)
    
    # Add target identifier
    processed_df['target'] = target
    
    processed_data[target] = processed_df

print("\nPreprocessing completed!")

In [None]:
# Plot processed light curves
fig, axes = plt.subplots(len(processed_data), 2, figsize=(20, 4*len(processed_data)))
if len(processed_data) == 1:
    axes = axes.reshape(1, -1)

for i, (target, df) in enumerate(processed_data.items()):
    # Original flux
    axes[i, 0].scatter(df['time'], df['original_flux'], alpha=0.6, s=1)
    axes[i, 0].set_title(f"Original Flux: {target}")
    axes[i, 0].set_xlabel("Time (days)")
    axes[i, 0].set_ylabel("Flux")
    
    # Processed flux
    axes[i, 1].scatter(df['time'], df['flux'], alpha=0.6, s=1, color='red')
    axes[i, 1].set_title(f"Processed Flux: {target}")
    axes[i, 1].set_xlabel("Time (days)")
    axes[i, 1].set_ylabel("Normalized Flux")

plt.tight_layout()
plt.show()

## 3. Feature Engineering and Label Creation

Let's create labels for transit detection and examine our features.
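`create_transit_labels` is not documented here, so the sketch below is only one plausible reading of its `transit_threshold` and `window_size` parameters: flag a point when the centered rolling median of flux around it sits at least `transit_threshold` below the target's median flux. It also shows per-target labeling that writes results back by index label, which keeps each label attached to the correct row. `threshold_transit_labels` is hypothetical, and the data are synthetic.

```python
import numpy as np
import pandas as pd


def threshold_transit_labels(flux, transit_threshold=0.01, window_size=5):
    """Hypothetical labeler: 1 where the centered rolling median of flux is
    at least `transit_threshold` below the median of the whole series.

    A rolling median needs most of the window to be low, so isolated noise
    dips are not flagged.
    """
    baseline = np.median(flux)
    smoothed = (pd.Series(flux)
                .rolling(window_size, center=True, min_periods=1)
                .median()
                .to_numpy())
    return (smoothed < baseline - transit_threshold).astype(int)


# Two synthetic targets: 2% deep, 0.1 d transits every 3 d on 0.1% noise
rng = np.random.default_rng(0)
frames = []
for name in ["SIM-A", "SIM-B"]:
    time = np.linspace(0.0, 27.0, 5000)
    truth = ((time % 3.0) < 0.1).astype(int)
    flux = 1.0 - 0.02 * truth + rng.normal(0.0, 0.001, time.size)
    frames.append(pd.DataFrame({"time": time, "flux": flux,
                                "target": name, "truth": truth}))
df = pd.concat(frames, ignore_index=True)

# Label each target separately and write back by index label
df["has_transit"] = 0
for target, idx in df.groupby("target").groups.items():
    df.loc[idx, "has_transit"] = threshold_transit_labels(df.loc[idx, "flux"].to_numpy())

print((df["has_transit"] == df["truth"]).mean())
```

Because these labels are a deterministic function of flux, any model given flux-derived features can learn to reproduce the rule; that is worth keeping in mind when reading the scores in the next sections.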

In [None]:
# Combine all processed data
combined_df = pd.concat(processed_data.values(), ignore_index=True)

print(f"Combined dataset shape: {combined_df.shape}")
print(f"Available columns: {list(combined_df.columns)}")
print("\nPoints per target:")
print(combined_df['target'].value_counts())

In [None]:
# Initialize classifier to create labels
classifier = TransitClassifier(model_type="random_forest")

# Create transit labels for each target, writing them back by row mask
combined_df['has_transit'] = 0

for target in combined_df['target'].unique():
    target_mask = combined_df['target'] == target
    target_data = combined_df[target_mask]
    
    # Create labels based on flux dips
    labels = np.asarray(classifier.create_transit_labels(
        target_data['time'].values,
        target_data['flux'].values,
        transit_threshold=0.01,  # 1% flux dip
        window_size=5
    ))
    
    # Assigning by mask keeps each label attached to its own target's rows
    combined_df.loc[target_mask, 'has_transit'] = labels.astype(int)
    print(f"{target}: {labels.sum()} in-transit points flagged out of {len(labels)} ({labels.mean()*100:.2f}%)")

print(f"\nTotal in-transit points: {combined_df['has_transit'].sum()} out of {len(combined_df)}")
print(f"Transit rate: {combined_df['has_transit'].mean()*100:.2f}%")

In [None]:
# Visualize detected transits
fig, axes = plt.subplots(len(processed_data), 1, figsize=(15, 4*len(processed_data)))
if len(processed_data) == 1:
    axes = [axes]

for i, target in enumerate(combined_df['target'].unique()):
    target_data = combined_df[combined_df['target'] == target]
    
    # Plot all points
    axes[i].scatter(target_data['time'], target_data['flux'], 
                   c='blue', alpha=0.6, s=1, label='All points')
    
    # Highlight transits
    transit_data = target_data[target_data['has_transit'] == 1]
    if len(transit_data) > 0:
        axes[i].scatter(transit_data['time'], transit_data['flux'], 
                       c='red', s=10, label='Transit')
    
    axes[i].set_title(f"Transit Detection: {target}")
    axes[i].set_xlabel("Time (days)")
    axes[i].set_ylabel("Normalized Flux")
    axes[i].legend()

plt.tight_layout()
plt.show()

## 4. Machine Learning Model Training

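Now let's train several ML models to flag in-transit points. Keep in mind that `has_transit` was derived from the processed `flux` column, and `flux` itself is among the feature columns selected below, so the models can partly recover the labeling rule directly from their inputs. High scores here do not by themselves show that a model generalizes to real transits.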
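Transit points are rare, so accuracy on its own says little: a model that never predicts a transit already scores the fraction of non-transit points. The sketch below computes that majority-class baseline and the weights scikit-learn assigns under `class_weight='balanced'` (`n_samples / (n_classes * count)`). Whether `TransitClassifier` applies class weighting internally is not shown in this notebook, and the label vector is illustrative rather than taken from the data above.

```python
import numpy as np


def majority_baseline_accuracy(y):
    """Accuracy of a classifier that always predicts the most common class."""
    _, counts = np.unique(y, return_counts=True)
    return float(counts.max() / counts.sum())


def balanced_class_weights(y):
    """Per-class weights n_samples / (n_classes * count_c), the formula behind
    scikit-learn's class_weight='balanced'."""
    classes, counts = np.unique(y, return_counts=True)
    return {int(c): float(len(y) / (len(classes) * n)) for c, n in zip(classes, counts)}


# Illustrative labels: 3% in-transit points
y = np.array([0] * 970 + [1] * 30)
print(majority_baseline_accuracy(y))  # 0.97
print(balanced_class_weights(y))
```

If a model's accuracy is close to this baseline, it has learned little about transits; precision, recall, and F1 on the transit class are more informative.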

In [None]:
# Prepare data for training
print("Feature columns available for training:")
feature_cols = [col for col in combined_df.columns if col not in ['time', 'target', 'has_transit', 'original_flux']]
print(feature_cols)

# Check for any missing values
print("\nMissing values per column:")
print(combined_df[feature_cols + ['has_transit']].isnull().sum())

In [None]:
# Train different types of models
model_types = ['random_forest', 'gradient_boosting', 'logistic']
trained_models = {}
training_results = {}

for model_type in model_types:
    print(f"\n{'='*50}")
    print(f"Training {model_type.upper()} model")
    print(f"{'='*50}")
    
    # Initialize and train model
    model = TransitClassifier(model_type=model_type)
    results = model.train(combined_df, target_column='has_transit')
    
    trained_models[model_type] = model
    training_results[model_type] = results
    
    print(f"\nTraining Results for {model_type}:")
    print(f"Accuracy: {results['training_accuracy']:.3f}")
    if 'training_auc' in results:
        print(f"AUC: {results['training_auc']:.3f}")
    
    # Show feature importance if available
    if 'feature_importance' in results:
        print("\nTop 5 most important features:")
        print(results['feature_importance'].head())

In [None]:
# Plot feature importance for tree-based models
fig, axes = plt.subplots(1, 2, figsize=(20, 6))

for i, model_type in enumerate(['random_forest', 'gradient_boosting']):
    if model_type in training_results and 'feature_importance' in training_results[model_type]:
        importance_df = training_results[model_type]['feature_importance'].head(10)
        
        axes[i].barh(range(len(importance_df)), importance_df['importance'])
        axes[i].set_yticks(range(len(importance_df)))
        axes[i].set_yticklabels(importance_df['feature'])
        axes[i].set_xlabel('Feature Importance')
        axes[i].set_title(f'Feature Importance - {model_type.title()}')
        axes[i].invert_yaxis()

plt.tight_layout()
plt.show()

## 5. Time Series Cross-Validation

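The training scores above are computed on the same rows each model was fitted on, so they overstate performance. Let's validate the models with time series cross-validation, which respects the temporal order of the data and groups points by target.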
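How `GroupedTimeSeriesValidator` builds its folds is not shown here. One common scheme consistent with its name is forward chaining within each group: sort every target's points by time, cut them into `n_splits + 1` consecutive blocks, and for fold k train on blocks 0..k of every target and test on block k + 1. The generator below is a hypothetical sketch of that scheme. Another common reading of "grouped" is to hold out whole targets (as scikit-learn's `GroupKFold` does), which tests generalization to unseen stars instead.

```python
import numpy as np


def grouped_forward_chaining_splits(time, groups, n_splits=3):
    """Yield (train_idx, test_idx) index arrays for forward-chaining CV.

    Within each group, points are sorted by time and split into
    n_splits + 1 contiguous blocks. Fold k trains on blocks 0..k of every
    group and tests on block k + 1, so no fold trains on a group's future.
    """
    time = np.asarray(time)
    groups = np.asarray(groups)
    blocks = []
    for g in np.unique(groups):
        idx = np.flatnonzero(groups == g)
        idx = idx[np.argsort(time[idx], kind="stable")]
        blocks.append(np.array_split(idx, n_splits + 1))
    for k in range(n_splits):
        train = np.concatenate([np.concatenate(b[:k + 1]) for b in blocks])
        test = np.concatenate([b[k + 1] for b in blocks])
        yield np.sort(train), np.sort(test)


# Two targets with 12 and 8 points on their own time axes
time = np.concatenate([np.arange(12.0), np.arange(8.0) + 100.0])
groups = np.array(["A"] * 12 + ["B"] * 8)
for k, (train, test) in enumerate(grouped_forward_chaining_splits(time, groups)):
    print(k, len(train), len(test))
```

Cutting each target separately, rather than cutting the concatenated time axis once, matters here because different TESS targets are observed in different sectors, so a single global time cut could put entire stars on one side of the split.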

In [None]:
# Initialize validator
validator = GroupedTimeSeriesValidator(
    n_splits=3,  # Use fewer splits for demonstration
    group_column='target'
)

# Prepare data for validation
X = combined_df[feature_cols].values
y = combined_df['has_transit'].values
groups = combined_df['target'].values

print(f"Validation data shape: X={X.shape}, y={y.shape}")
print(f"Class distribution: {np.bincount(y)}")

In [None]:
# Validate each model
validation_results = {}

for model_type, model in trained_models.items():
    print(f"\n{'='*50}")
    print(f"Validating {model_type.upper()} model")
    print(f"{'='*50}")
    
    # Perform time series cross-validation
    cv_results = validator.validate_model(
        model.model, X, y, groups=groups,
        scoring_metrics=['accuracy', 'precision', 'recall', 'f1', 'roc_auc']
    )
    
    validation_results[model_type] = cv_results
    
    # Print summary
    summary = cv_results['summary']
    print("\nCross-Validation Results:")
    for metric in ['accuracy', 'precision', 'recall', 'f1', 'roc_auc']:
        mean_key = f'{metric}_mean'
        std_key = f'{metric}_std'
        if mean_key in summary:
            print(f"{metric.capitalize()}: {summary[mean_key]:.3f} ± {summary[std_key]:.3f}")

In [None]:
# Compare model performance
comparison_data = []

for model_type, results in validation_results.items():
    summary = results['summary']
    comparison_data.append({
        'Model': model_type.title(),
        'Accuracy': summary.get('accuracy_mean', np.nan),
        'Precision': summary.get('precision_mean', np.nan),
        'Recall': summary.get('recall_mean', np.nan),
        'F1': summary.get('f1_mean', np.nan),
        'ROC AUC': summary.get('roc_auc_mean', np.nan)
    })

comparison_df = pd.DataFrame(comparison_data)
print("\nModel Comparison (Cross-Validation Scores):")
print(comparison_df.round(3))

In [None]:
# Visualize model comparison
metrics_to_plot = ['Accuracy', 'Precision', 'Recall', 'F1']
n_metrics = len(metrics_to_plot)

fig, axes = plt.subplots(1, n_metrics, figsize=(5*n_metrics, 6))

for i, metric in enumerate(metrics_to_plot):
    comparison_df.plot(x='Model', y=metric, kind='bar', ax=axes[i], 
                      color=['skyblue', 'lightcoral', 'lightgreen'])
    axes[i].set_title(f'{metric} Comparison')
    axes[i].set_ylabel(metric)
    axes[i].set_ylim(0, 1)
    axes[i].tick_params(axis='x', rotation=45)
    axes[i].legend().remove()

plt.tight_layout()
plt.show()

## 6. Final Evaluation and Insights

Let's create a final evaluation using a temporal split and analyze our results.
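The final report relies on the standard confusion-matrix metrics. The sketch below spells them out with NumPy; `binary_metrics` is a hypothetical helper, not the package's `evaluate_predictions`. The worked example shows why accuracy alone is misleading for rare transit points: 1000 points, 30 of them in transit, and a model that flags 25 points of which 20 are real.

```python
import numpy as np


def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 for binary labels (1 = transit)."""
    y_true = np.asarray(y_true).astype(bool)
    y_pred = np.asarray(y_pred).astype(bool)
    tp = int(np.sum(y_true & y_pred))    # transit points correctly flagged
    fp = int(np.sum(~y_true & y_pred))   # false alarms
    fn = int(np.sum(y_true & ~y_pred))   # missed transit points
    tn = int(np.sum(~y_true & ~y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": (tp + tn) / len(y_true), "precision": precision,
            "recall": recall, "f1": f1}


# 30 true transit points among 1000; the model flags 20 of them plus 5 false alarms
y_true = np.array([1] * 30 + [0] * 970)
y_pred = np.array([1] * 20 + [0] * 10 + [1] * 5 + [0] * 965)
print(binary_metrics(y_true, y_pred))
```

Accuracy comes out at 0.985 even though a third of the transit points are missed (recall 2/3, precision 0.8, F1 = 40/55, about 0.727), so precision, recall, and F1 are the numbers to compare.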

In [None]:
# Create temporal train-test split
train_df, test_df = validator.temporal_validation_split(
    combined_df, time_column='time', test_ratio=0.3
)

# Select best model based on cross-validation F1 score
best_model_type = comparison_df.loc[comparison_df['F1'].idxmax(), 'Model'].lower().replace(' ', '_')
best_model = trained_models[best_model_type]

print(f"Best model: {best_model_type}")
print(f"Best cross-validated F1 score: {comparison_df['F1'].max():.3f}")

In [None]:
# Retrain best model on training data
print("Retraining best model on temporal training split...")
retrain_results = best_model.train(train_df, target_column='has_transit')

# Make predictions on test set
test_predictions, test_probabilities = best_model.predict(test_df)

# Evaluate final performance
final_metrics = validator.evaluate_predictions(
    test_df['has_transit'].values, 
    test_predictions, 
    test_probabilities
)

validator.print_evaluation_report(final_metrics)

In [None]:
# Visualize final results
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# 1. Test set predictions over time
axes[0, 0].scatter(test_df['time'], test_df['flux'], c='blue', alpha=0.6, s=1, label='All test points')
transit_mask = test_predictions == 1
if np.any(transit_mask):
    axes[0, 0].scatter(test_df.iloc[transit_mask]['time'], 
                      test_df.iloc[transit_mask]['flux'], 
                      c='red', s=10, label='Predicted Transit')
axes[0, 0].set_title('Final Model Predictions on Test Set')
axes[0, 0].set_xlabel('Time (days)')
axes[0, 0].set_ylabel('Normalized Flux')
axes[0, 0].legend()

# 2. Prediction probabilities histogram
axes[0, 1].hist(test_probabilities, bins=50, alpha=0.7, edgecolor='black')
axes[0, 1].axvline(x=0.5, color='red', linestyle='--', label='Decision Threshold')
axes[0, 1].set_title('Distribution of Prediction Probabilities')
axes[0, 1].set_xlabel('Transit Probability')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].legend()

# 3. True vs predicted scatter plot
axes[1, 0].scatter(test_df['has_transit'], test_probabilities, alpha=0.6)
axes[1, 0].set_xlabel('True Labels')
axes[1, 0].set_ylabel('Predicted Probabilities')
axes[1, 0].set_title('True Labels vs Predicted Probabilities')
axes[1, 0].set_xticks([0, 1])
axes[1, 0].set_xticklabels(['No Transit', 'Transit'])

# 4. Performance metrics bar chart
metrics_names = ['Accuracy', 'Precision', 'Recall', 'F1']
metrics_values = [final_metrics['accuracy'], final_metrics['precision'], 
                 final_metrics['recall'], final_metrics['f1']]
axes[1, 1].bar(metrics_names, metrics_values, color=['skyblue', 'lightcoral', 'lightgreen', 'gold'])
axes[1, 1].set_title('Final Model Performance')
axes[1, 1].set_ylabel('Score')
axes[1, 1].set_ylim(0, 1)

plt.tight_layout()
plt.show()
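The probability histogram marks 0.5 as the decision threshold, but with rare positives a different threshold often gives a better precision/recall balance. The hypothetical helper below scans candidate thresholds and keeps the one with the highest F1, using the identity F1 = 2TP / (2TP + FP + FN). To avoid an optimistic result, choose the threshold on validation predictions (for example from a cross-validation fold), not on the test set it will be reported on. The scores in the example are made up to illustrate the effect.

```python
import numpy as np


def best_f1_threshold(y_true, proba, thresholds=None):
    """Return (threshold, f1) maximizing F1 over the candidate thresholds.

    Ties keep the first candidate scanned.
    """
    if thresholds is None:
        thresholds = np.linspace(0.05, 0.95, 19)
    y_true = np.asarray(y_true).astype(bool)
    proba = np.asarray(proba)
    best_t, best_f1 = 0.5, -1.0
    for t in thresholds:
        pred = proba >= t
        tp = np.sum(y_true & pred)
        fp = np.sum(~y_true & pred)
        fn = np.sum(y_true & ~pred)
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best_f1:
            best_t, best_f1 = float(t), float(f1)
    return best_t, best_f1


# Made-up scores: positives sit at 0.32-0.45, negatives at 0.00-0.18
y_true = np.array([1] * 10 + [0] * 90)
proba = np.concatenate([np.linspace(0.32, 0.45, 10), np.linspace(0.0, 0.18, 90)])
print(best_f1_threshold(y_true, proba))
```

With these scores every point falls below 0.5, so the default threshold gives an F1 of 0, while a threshold around 0.2 separates the classes perfectly.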

## 7. Summary and Conclusions

This notebook demonstrated a complete machine learning pipeline for exoplanet detection from transit data:

### Key Steps:
1. **Data Loading**: Used lightkurve to download real TESS light curves
2. **Preprocessing**: Applied detrending, outlier removal, and normalization
3. **Feature Engineering**: Created time-based features and transit labels
4. **Model Training**: Trained multiple ML models (Random Forest, Gradient Boosting, Logistic Regression)
5. **Validation**: Used time series cross-validation to properly evaluate models
6. **Final Evaluation**: Tested best model on held-out temporal test set

### Key Insights:
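- Machine learning models can learn to flag transit-like dips in light curves; because the labels here come from a flux-threshold rule, the models are learning to reproduce that rule rather than confirmed transits
- Time series cross-validation gives a more realistic estimate than training scores, which are measured on the data the models were fitted on
- Features that capture temporal context are expected to help, since a transit spans many consecutive points, but this notebook includes no ablation that isolates their effect
- Tree-based models suit this kind of tabular, per-point feature data; the cross-validation table shows how they compare with logistic regression on these targets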

### Future Improvements:
- Use larger datasets with more diverse targets
- Implement more sophisticated transit detection algorithms
- Add physics-based features (period detection, transit duration, etc.)
- Experiment with deep learning approaches
- Incorporate additional data sources (stellar parameters, etc.)
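Period detection is usually done with a box least squares (BLS) search, for example `astropy.timeseries.BoxLeastSquares`. The hypothetical function below is a much cruder relative of BLS, included only to show the idea: fold the light curve at each trial period, average the flux in phase bins, and record how far the lowest bin falls below the median. The true period stacks every transit into the same bins and gives the deepest dip; half the true period gives roughly half the depth, because only every other fold contains a transit. The data are synthetic.

```python
import numpy as np


def box_period_search(time, flux, periods, n_bins=50):
    """Return (best_period, depths) from a brute-force phase-fold box search.

    For each trial period the light curve is folded, averaged in n_bins
    phase bins, and the depth of the lowest bin below the median flux is
    recorded. The period with the deepest binned dip wins.
    """
    baseline = np.median(flux)
    depths = np.empty(len(periods))
    for i, period in enumerate(periods):
        phase = (time % period) / period                       # in [0, 1)
        bins = np.minimum((phase * n_bins).astype(int), n_bins - 1)
        sums = np.bincount(bins, weights=flux, minlength=n_bins)
        counts = np.bincount(bins, minlength=n_bins)
        with np.errstate(invalid="ignore", divide="ignore"):
            means = sums / counts                              # NaN for empty bins
        depths[i] = baseline - np.nanmin(means)
    return periods[np.argmax(depths)], depths


# Synthetic light curve: 1% deep, 0.1 d transits every 3 d, first at t = 0.5 d
rng = np.random.default_rng(42)
time = np.linspace(0.0, 27.0, 5000)        # roughly one TESS sector, in days
true_period, duration, depth = 3.0, 0.1, 0.01
in_transit = ((time - 0.5) % true_period) < duration
flux = 1.0 + rng.normal(0.0, 0.001, time.size)
flux[in_transit] -= depth

periods = np.linspace(1.0, 5.0, 4001)      # 0.001 d grid
best_period, depths = box_period_search(time, flux, periods)
print(f"best period: {best_period:.3f} d, depth: {depths.max():.4f}")
```

The trial-period grid must be fine enough that transits do not drift out of phase across the baseline: with N transits of duration d, the step should be well below d / N (here 0.1 / 9, about 0.011 days, versus a 0.001-day step).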

This framework provides a solid foundation for developing more advanced exoplanet detection systems using machine learning.

In [None]:
# Save the best model for future use
model_save_path = '../models/best_transit_classifier.joblib'
Path('../models').mkdir(parents=True, exist_ok=True)

best_model.save_model(model_save_path)
print(f"Best model saved to {model_save_path}")

# Save final results (create the output directory first)
Path('../results').mkdir(parents=True, exist_ok=True)
results_df = pd.DataFrame([final_metrics])
results_df.to_csv('../results/final_evaluation.csv', index=False)
print("Final evaluation results saved to ../results/final_evaluation.csv")