# 03b_cnn_train — 1D-CNN Training with Global/Local Views

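This notebook trains a 1D-CNN on (time, flux) light curves that are phase-folded with the period, epoch (T0), and duration found by the TLS/BLS transit search; synthetic light curves are used when no search results are available.
The model takes two inputs: a global view of the full phase-folded light curve and a local view zoomed in on the transit.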

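Outputs:
- `artifacts/cnn1d.pt` - Model weights
- `artifacts/calibrator.joblib` - Probability calibration
- `reports/metrics_cnn.json` - Performance metrics
- `reports/training_history_cnn.png` - Training/validation loss and accuracy curves
- `reports/curves_cnn.png` - ROC and precision-recall curves
- `reports/calibration_cnn.png` - Calibration plot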

In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, TensorDataset
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Import custom modules
import sys
sys.path.append('..')
from app.models.cnn1d import make_model
from app.data.fold import Item, LightCurveViewsDataset
from app.trainers.cnn1d_trainer import train
from app.calibration.calibrate import run_and_save

# Reproducibility: seed every RNG this notebook draws from -- NumPy's global
# generator (synthetic light curves, noise) and torch (weight init,
# DataLoader shuffling). torch.manual_seed seeds all devices, CUDA included.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Setup device: prefer CUDA, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    device = 'mps'
    print("Using Apple MPS")
else:
    device = 'cpu'
    print("Using CPU")

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")

## 1. Load Data and Transit Parameters

Load the denoised light curve and the transit parameters from the TLS/BLS search results. If either artifact is missing, the notebook falls back to synthetic data.

In [None]:
# Load transit search results from previous step
transit_results_path = Path('../artifacts/transit_search_results.json')
denoised_lc_path = Path('../artifacts/denoised_lc.npz')

# Real data needs BOTH the search results and the denoised light curve.
use_real_data = transit_results_path.exists() and denoised_lc_path.exists()

if use_real_data:
    print("Loading real data from TLS search...")
    with open(transit_results_path, 'r') as f:
        transit_results = json.load(f)

    # np.load on an .npz returns an NpzFile that holds the archive open until
    # closed; the context manager releases it. Indexing reads each array into
    # memory, so t_real / flux_real remain valid after the block.
    with np.load(denoised_lc_path) as lc_data:
        t_real = lc_data['time']
        flux_real = lc_data['flux_denoised']

    # Use TLS results if available, otherwise BLS
    search_method = next((m for m in ('tls', 'bls') if transit_results.get(m)), None)
    if search_method is not None:
        search_params = transit_results[search_method]
        period = search_params['period']
        t0 = search_params['t0']
        duration = search_params['duration']
        print(f"Using {search_method.upper()} parameters: P={period:.4f}d, T0={t0:.4f}d, Dur={duration:.4f}d")
    else:
        use_real_data = False
        print("No transit parameters found, using synthetic data")
else:
    print("No real data found, using synthetic data for demo")

In [None]:
# Generate synthetic dataset for training
def generate_synthetic_dataset(n_samples=1000, test_split=0.2, val_split=0.1, rng=None):
    """Generate synthetic light curves with and without transits.

    Even indices are planet-hosting light curves (label 1) with a box-shaped
    transit centred on ``t0 + k * period``. This matches the mid-transit
    epoch convention of the TLS/BLS parameters used on the real-data path.
    Odd indices are non-transiting light curves (label 0) containing only
    sinusoidal variability. Every sample gets Gaussian white noise.

    Both classes share the same cadence (4096 points) and baseline
    (6 periods), so the time span alone cannot reveal the label.

    Parameters
    ----------
    n_samples : int
        Number of light curves. Classes alternate, so an even value gives a
        balanced dataset.
    test_split, val_split : float
        Unused; kept for backward compatibility. Splitting is done later
        with ``random_split``.
    rng : numpy.random.Generator, optional
        Source of randomness. ``None`` uses NumPy's global generator, so the
        output follows ``np.random.seed``.

    Returns
    -------
    list of Item
        One ``Item`` per light curve (time, flux, period, t0, duration, label).
    """
    if rng is None:
        # The np.random module functions share the seedable global state and
        # expose the same uniform/normal API as a Generator.
        rng = np.random

    n_points = 4096
    items = []

    for i in range(n_samples):
        # Alternate between planet and non-planet
        has_planet = i % 2 == 0
        period = rng.uniform(1.5, 10.0)
        t = np.linspace(0, period * 6, n_points)
        flux = np.ones_like(t)

        if has_planet:
            t0 = rng.uniform(0, 2.0)
            duration = rng.uniform(0.05, 0.15)  # In days
            depth = rng.uniform(0.0005, 0.003)

            # Phase in [-0.5, 0.5) cycles with 0 at mid-transit, so each
            # transit spans t0 +/- duration / 2.
            phase = ((t - t0) / period + 0.5) % 1.0 - 0.5
            in_transit = np.abs(phase) < 0.5 * duration / period
            flux[in_transit] -= depth

            # Add stellar variability
            flux += 0.0001 * np.sin(2 * np.pi * t / 3.5)

            label = 1
        else:
            # Placeholder ephemeris: there is no transit to fold on.
            t0 = 0.0
            duration = 0.1

            # Add only stellar variability and spots
            flux += 0.0002 * np.sin(2 * np.pi * t / rng.uniform(2, 8))
            flux += 0.0001 * np.sin(2 * np.pi * t / rng.uniform(0.5, 2))

            label = 0

        # Add white noise
        noise_level = rng.uniform(3e-4, 8e-4)
        flux += rng.normal(0, noise_level, size=flux.shape)

        items.append(Item(
            time=t,
            flux=flux,
            period=period,
            t0=t0,
            duration=duration,
            label=label
        ))

    return items

# Generate or load dataset
if use_real_data:
    # Demo only: every sample below is derived from ONE real light curve, so
    # train/val/test contain near-duplicates and the resulting metrics are a
    # pipeline smoke test, not an estimate of generalization.
    rng_real = np.random.default_rng(42)

    # Negatives must genuinely lack the transit. Otherwise identical light
    # curves would carry both labels and the labels would be pure noise.
    # Assumes t0 is the mid-transit epoch (as TLS/BLS report it). The mask
    # half-width is one full duration (2x the nominal half-width) as a safety
    # margin, and masked points are filled with the out-of-transit median.
    phase_real = ((t_real - t0) / period + 0.5) % 1.0 - 0.5
    in_transit_real = np.abs(phase_real) * period < duration
    if in_transit_real.all():
        raise ValueError("Transit mask covers the whole light curve; cannot build negative samples")
    flux_no_transit = np.array(flux_real, copy=True)
    flux_no_transit[in_transit_real] = np.nanmedian(flux_real[~in_transit_real])

    items = []
    n_real_samples = 100
    for i in range(n_real_samples):
        has_planet = i < n_real_samples // 2  # Half with planets, half without
        base_flux = flux_real if has_planet else flux_no_transit
        # Add noise variations
        flux_variation = base_flux + rng_real.normal(0, 2e-4, size=base_flux.shape)
        items.append(Item(
            time=t_real,
            flux=flux_variation,
            period=period,
            t0=t0,
            duration=duration,
            label=1 if has_planet else 0
        ))
    print(f"Created {len(items)} samples from real data")
else:
    # Generate synthetic dataset
    items = generate_synthetic_dataset(n_samples=500)
    print(f"Generated {len(items)} synthetic samples")

# Create dataset
dataset = LightCurveViewsDataset(items)
print(f"Dataset size: {len(dataset)}")
print(f"Positive samples: {sum(item.label for item in items)}")
print(f"Negative samples: {sum(1-item.label for item in items)}")

## 2. Create Train/Validation/Test Splits

In [None]:
# 70 / 15 / 15 train/validation/test split. The fixed generator seed makes
# the partition identical across runs; the test split absorbs any rounding.
n_total = len(dataset)
n_train, n_val = (int(n_total * frac) for frac in (0.7, 0.15))
n_test = n_total - (n_train + n_val)

split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds, test_ds = random_split(
    dataset,
    [n_train, n_val, n_test],
    generator=split_generator
)

for split_name, split_ds in (("Train", train_ds), ("Validation", val_ds), ("Test", test_ds)):
    print(f"{split_name}: {len(split_ds)} samples")

# Larger batches when an accelerator is available; only training is shuffled.
batch_size = 64 if device != 'cpu' else 32
train_loader, val_loader, test_loader = (
    DataLoader(train_ds, batch_size=batch_size, shuffle=True),
    DataLoader(val_ds, batch_size=batch_size, shuffle=False),
    DataLoader(test_ds, batch_size=batch_size, shuffle=False),
)

## 3. Visualize Global and Local Views

In [None]:
# Inspect one example: the global (full-phase) and local (transit zoom)
# views that the CNN receives as its two inputs.
sample_idx = 0
global_view, local_view, label = dataset[sample_idx]

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

view_panels = [
    (global_view, 'b-', f'Global View (Label: {label})'),
    (local_view, 'r-', 'Local View (Transit region)'),
]
for ax, (view, line_fmt, panel_title) in zip(axes, view_panels):
    ax.plot(view, line_fmt, linewidth=0.5)
    ax.set_title(panel_title)
    ax.set_xlabel('Phase bin')
    ax.set_ylabel('Normalized flux')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

class_name = 'Planet' if label == 1 else 'No planet'
print(f"Global view shape: {global_view.shape}")
print(f"Local view shape: {local_view.shape}")
print(f"Label: {label} ({class_name})")

## 4. Initialize and Train CNN Model

In [None]:
# Build the CNN and report its size.
model = make_model()
print("Model architecture:")
print(model)

# Tally parameter counts once, separating those with requires_grad=True.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, needs_grad in param_sizes if needs_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

In [None]:
# Training parameters. Every entry below is forwarded to train(), so the
# printed configuration reflects what the run actually uses.
# NOTE(review): a `min_delta` (1e-4) entry used to be listed here but was never
# passed to train(); re-add it together with the keyword argument if train()
# supports one.
training_config = {
    'batch_size': batch_size,
    'lr': 1e-3,
    'max_epochs': 50,
    'patience': 10,
    'device': device
}

print("Training configuration:")
for key, value in training_config.items():
    print(f"  {key}: {value}")

# Create output directories (paths are relative to the current working directory).
artifacts_dir = Path('../artifacts')
artifacts_dir.mkdir(parents=True, exist_ok=True)
reports_dir = Path('../reports')
reports_dir.mkdir(parents=True, exist_ok=True)

# Train model. workdir is the parent of artifacts_dir. The evaluation cell
# later reloads weights from artifacts/cnn1d.pt, which is presumably where
# train() writes its best checkpoint -- confirm against the trainer.
print("\nStarting training...")
metrics = train(
    model=model,
    train_ds=train_ds,
    val_ds=val_ds,
    device=training_config['device'],
    batch_size=training_config['batch_size'],
    lr=training_config['lr'],
    max_epochs=training_config['max_epochs'],
    patience=training_config['patience'],
    workdir=str(artifacts_dir.parent)
)

print("\nTraining completed!")
print(f"Final validation loss: {metrics['val_loss'][-1]:.4f}")
print(f"Final validation accuracy: {metrics['val_acc'][-1]:.4f}")

## 5. Plot Training History

In [None]:
# Plot per-epoch train vs. validation curves (loss, accuracy) and save the figure.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

history_panels = [
    # (metrics key suffix, legend suffix, y-axis label, panel title)
    ('loss', 'Loss', 'Loss', 'Training and Validation Loss'),
    ('acc', 'Acc', 'Accuracy', 'Training and Validation Accuracy'),
]
for ax, (suffix, legend_name, y_label, panel_title) in zip(axes, history_panels):
    ax.plot(metrics[f'train_{suffix}'], label=f'Train {legend_name}')
    ax.plot(metrics[f'val_{suffix}'], label=f'Val {legend_name}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(reports_dir / 'training_history_cnn.png', dpi=100, bbox_inches='tight')
plt.show()

## 6. Evaluate on Test Set

In [None]:
# Load best model weights from artifacts/cnn1d.pt (presumably the checkpoint
# train() saved -- confirm against the trainer).
model = make_model()
model.load_state_dict(torch.load(artifacts_dir / 'cnn1d.pt', map_location=device))
model.to(device)
model.eval()

# Evaluate on test set
test_preds = []
test_probs = []
test_labels = []

with torch.no_grad():
    for global_batch, local_batch, label_batch in test_loader:
        # as_tensor converts arrays and reuses tensors that already have the
        # requested dtype/device. torch.tensor() would copy an existing tensor
        # and emit a "copy construct from a tensor" UserWarning.
        global_batch = torch.as_tensor(global_batch, dtype=torch.float32, device=device)
        local_batch = torch.as_tensor(local_batch, dtype=torch.float32, device=device)

        # Forward pass: one logit per sample -> probability -> hard label at 0.5
        logits = model(global_batch, local_batch).squeeze(1)
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).float()

        # Store results on the host
        test_probs.append(probs.cpu().numpy())
        test_preds.append(preds.cpu().numpy())
        test_labels.append(np.asarray(label_batch))

# Concatenate results
test_probs = np.concatenate(test_probs)
test_preds = np.concatenate(test_preds)
test_labels = np.concatenate(test_labels)

# Calculate metrics. zero_division=0 makes the "no positive predictions"
# case explicit instead of relying on a (globally suppressed) warning.
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

test_metrics = {
    'accuracy': accuracy_score(test_labels, test_preds),
    'precision': precision_score(test_labels, test_preds, zero_division=0),
    'recall': recall_score(test_labels, test_preds, zero_division=0),
    'f1': f1_score(test_labels, test_preds, zero_division=0),
    'roc_auc': roc_auc_score(test_labels, test_probs)
}

print("Test Set Metrics:")
for metric, value in test_metrics.items():
    print(f"  {metric}: {value:.4f}")

## 7. Probability Calibration

In [None]:
# Perform probability calibration on validation set.
# NOTE(review): val_ds was also passed to train() (presumably for early
# stopping via `patience`), so the calibrator is fit on data that already
# influenced model selection.
model.eval()  # make this cell safe to run on its own
val_probs = []
val_labels = []

with torch.no_grad():
    for global_batch, local_batch, label_batch in val_loader:
        # as_tensor avoids the extra copy/warning torch.tensor() causes on tensors
        global_batch = torch.as_tensor(global_batch, dtype=torch.float32, device=device)
        local_batch = torch.as_tensor(local_batch, dtype=torch.float32, device=device)

        logits = model(global_batch, local_batch).squeeze(1)
        probs = torch.sigmoid(logits)

        val_probs.append(probs.cpu().numpy())
        val_labels.append(np.asarray(label_batch))

val_probs = np.concatenate(val_probs)
val_labels = np.concatenate(val_labels)

# Run calibration (fits and saves the calibrator under artifacts_dir)
print("\nRunning probability calibration...")
cal_info = run_and_save(
    val_labels,
    val_probs,
    out_dir=str(artifacts_dir),
    method='isotonic'
)

print(f"Calibration method: {cal_info['method']}")
print(f"ECE before: {cal_info['ece_before']:.4f}")
print(f"ECE after: {cal_info['ece_after']:.4f}")
print(f"Brier score before: {cal_info['brier_before']:.4f}")
print(f"Brier score after: {cal_info['brier_after']:.4f}")

## 8. ROC and PR Curves

In [None]:
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
import joblib

# Load calibrator. joblib deserializes with pickle, so only load artifacts
# this pipeline produced itself.
calibrator = joblib.load(artifacts_dir / 'calibrator.joblib')

# Calibrated test probabilities, kept for later use as confidence scores.
test_probs_cal = calibrator.transform(test_probs.reshape(-1, 1)).ravel()

# Ranking metrics (ROC / PR) use the RAW model scores. Isotonic calibration is
# only weakly monotonic: it merges distinct scores into ties, which can lower
# ROC-AUC / AP. Raw scores also keep these numbers consistent with the
# roc_auc computed from test_probs in section 6.
fpr, tpr, _ = roc_curve(test_labels, test_probs)
roc_auc = auc(fpr, tpr)

# PR curve
precision, recall, _ = precision_recall_curve(test_labels, test_probs)
pr_auc = average_precision_score(test_labels, test_probs)

# Plot curves
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# ROC
axes[0].plot(fpr, tpr, 'b-', linewidth=2, label=f'CNN (AUC = {roc_auc:.3f})')
axes[0].plot([0, 1], [0, 1], 'k--', alpha=0.3)
axes[0].set_xlabel('False Positive Rate')
axes[0].set_ylabel('True Positive Rate')
axes[0].set_title('ROC Curve')
axes[0].legend(loc='lower right')
axes[0].grid(True, alpha=0.3)

# PR
axes[1].plot(recall, precision, 'r-', linewidth=2, label=f'CNN (AP = {pr_auc:.3f})')
axes[1].set_xlabel('Recall')
axes[1].set_ylabel('Precision')
axes[1].set_title('Precision-Recall Curve')
axes[1].legend(loc='lower left')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(reports_dir / 'curves_cnn.png', dpi=100, bbox_inches='tight')
plt.show()

## 9. Save Final Metrics

In [None]:
# Gather training setup, split sizes, test performance and calibration results
# into one JSON-serializable report (NumPy scalars cast to plain floats).
training_summary = {
    'epochs': len(metrics['train_loss']),
    'batch_size': training_config['batch_size'],
    'learning_rate': training_config['lr'],
    'device': device,
    'total_params': total_params,
    'trainable_params': trainable_params,
}
dataset_summary = {
    'total_samples': n_total,
    'train_samples': n_train,
    'val_samples': n_val,
    'test_samples': n_test,
}
performance_summary = {
    f'test_{name}': float(test_metrics[name])
    for name in ('accuracy', 'precision', 'recall', 'f1', 'roc_auc')
}
performance_summary['test_pr_auc'] = float(pr_auc)
performance_summary['val_loss_final'] = float(metrics['val_loss'][-1])
performance_summary['val_acc_final'] = float(metrics['val_acc'][-1])

calibration_summary = {'method': cal_info['method']}
calibration_summary.update(
    {key: float(cal_info[key]) for key in ('ece_before', 'ece_after', 'brier_before', 'brier_after')}
)

final_metrics = {
    'model': 'CNN1D',
    'training': training_summary,
    'dataset': dataset_summary,
    'performance': performance_summary,
    'calibration': calibration_summary,
}

# Save metrics to JSON
metrics_file = reports_dir / 'metrics_cnn.json'
with metrics_file.open('w') as f:
    json.dump(final_metrics, f, indent=2)

print(f"\nMetrics saved to {metrics_file}")
print("\nFinal CNN Performance Summary:")
summary_rows = [
    ("Accuracy", test_metrics['accuracy']),
    ("Precision", test_metrics['precision']),
    ("Recall", test_metrics['recall']),
    ("F1 Score", test_metrics['f1']),
    ("ROC-AUC", test_metrics['roc_auc']),
    ("PR-AUC", pr_auc),
    ("ECE (calibrated)", cal_info['ece_after']),
]
for row_label, row_value in summary_rows:
    print(f"  {row_label}: {row_value:.3f}")

## 10. Inference Speed Test

In [None]:
import time

# Benchmark forward-pass latency and throughput at several batch sizes.
model.eval()
batch_sizes = [1, 8, 32, 64]
n_warmup = 10
n_iterations = 100
latency_results = {}

# Dummy view lengths. NOTE(review): these must match the global/local view
# lengths produced by LightCurveViewsDataset (see the shapes printed in
# section 3); update them if the dataset's binning changes.
global_len = 256
local_len = 64


def sync_device():
    """Block until queued GPU work on `device` has finished.

    CUDA and MPS execute kernels asynchronously. Without a sync, the timer
    would mostly measure kernel-launch overhead. torch.mps is guarded because
    it only exists in newer PyTorch releases.
    """
    if device == 'cuda':
        torch.cuda.synchronize()
    elif device == 'mps' and hasattr(torch, 'mps'):
        torch.mps.synchronize()


with torch.no_grad():
    for bs in batch_sizes:
        # Create dummy input
        dummy_global = torch.randn(bs, global_len).to(device)
        dummy_local = torch.randn(bs, local_len).to(device)

        # Warmup (lazy init, kernel selection, caches) -- excluded from timing
        for _ in range(n_warmup):
            model(dummy_global, dummy_local)

        # Time inference with a monotonic, high-resolution clock
        sync_device()
        start = time.perf_counter()
        for _ in range(n_iterations):
            model(dummy_global, dummy_local)
        sync_device()
        elapsed = time.perf_counter() - start

        avg_latency = elapsed / n_iterations * 1000  # Convert to ms
        throughput = bs * n_iterations / elapsed

        latency_results[bs] = {
            'latency_ms': avg_latency,
            'throughput_samples_per_sec': throughput
        }

        print(f"Batch size {bs:3d}: {avg_latency:.2f} ms/batch, {throughput:.1f} samples/sec")

# Add to final metrics
final_metrics['inference_speed'] = latency_results

# Save updated metrics (JSON turns the integer batch-size keys into strings)
with open(metrics_file, 'w') as f:
    json.dump(final_metrics, f, indent=2)

## Summary

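This notebook trained a 1D-CNN model for exoplanet detection using:
- Global and local phase-folded views
- TLS/BLS transit parameters (or synthetic light curves when no search results are available)
- GPU acceleration (CUDA or Apple MPS) when available
- Probability calibration for reliable confidence scores

Note: on the real-data path, every sample is derived from a single light curve with added noise. The train/validation/test sets therefore contain near-duplicates, and the reported metrics are a pipeline smoke test rather than an estimate of real-world performance.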

Key outputs:
- ✅ Model weights: `artifacts/cnn1d.pt`
- ✅ Calibrator: `artifacts/calibrator.joblib`
- ✅ Metrics: `reports/metrics_cnn.json`
- ✅ Plots: `reports/training_history_cnn.png`, `reports/curves_cnn.png`, `reports/calibration_cnn.png`

The model is now ready for inference on new data (see notebook 04).