# 02_tls_search — GP 去噪 + TLS 週期搜尋

Pipeline: raw light curve (LC) → GP denoising → TLS search → best-fit transit parameters (period, t0, duration)

Requirements:
- `pip install transitleastsquares` (TLS)
- `pip install celerite2` or `pip install starry_process` (GP, optional)
- `pip install lightkurve astropy` (for data loading and optional BLS)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import json
import warnings
warnings.filterwarnings('ignore')

# Import custom modules
import sys
sys.path.append('..')
from app.denoise.gp import denoise
from app.search.tls_runner import run_tls

## 1. Load Light Curve Data

Load the light curve from one of the following sources:
- Synthetic data for testing
- Real TESS/Kepler data
- Preprocessed data from artifacts/

In [None]:
# Option 1: Generate synthetic light curve with transit
def generate_synthetic_lc(n_points=7000, period=13.7, depth=0.001, duration_hours=3, noise_level=3e-4, seed=42,
                          time_span=27.4, t0=2.0):
    """Generate a synthetic light curve containing a box-shaped transit signal.

    The curve is a unit baseline with periodic box transits, a sinusoidal
    modulation of amplitude 2e-4 and period 5.5 days, and Gaussian white noise.

    Parameters
    ----------
    n_points : int
        Number of evenly spaced samples.
    period : float
        Transit period in days; must be positive.
    depth : float
        Flux drop subtracted inside each transit.
    duration_hours : float
        Full transit duration in hours.
    noise_level : float
        Standard deviation of the Gaussian noise.
    seed : int
        Seed passed to ``np.random.default_rng``.
    time_span : float
        Length of the time axis in days; samples cover ``[0, time_span]``.
    t0 : float
        Mid-time (days) of the reference transit.

    Returns
    -------
    tuple of numpy.ndarray
        ``(t, flux)``: sample times in days and the simulated flux.

    Raises
    ------
    ValueError
        If ``period`` is not positive.
    """
    if period <= 0:
        raise ValueError("period must be positive")

    rng = np.random.default_rng(seed)

    # Time array (days)
    t = np.linspace(0.0, time_span, n_points)

    # Base flux
    flux = np.ones_like(t)

    # Add transit signal (simple box model)
    half_duration = duration_hours / 24.0 / 2.0

    # Enumerate every epoch whose box can touch [0, time_span]. Negative epochs
    # matter when period < t0: transits before the reference one are in range.
    first_epoch = int(np.ceil((-half_duration - t0) / period))
    last_epoch = int(np.floor((time_span + half_duration - t0) / period))
    for epoch in range(first_epoch, last_epoch + 1):
        transit_center = t0 + epoch * period
        in_transit = np.abs(t - transit_center) < half_duration
        flux[in_transit] -= depth

    # Add stellar variability (sinusoidal)
    flux += 0.0002 * np.sin(2 * np.pi * t / 5.5)  # Rotation signal

    # Add noise
    flux += rng.normal(0, noise_level, size=t.shape)

    return t, flux

# Build the synthetic light curve using the generator's default settings
t_synth, flux_synth = generate_synthetic_lc()
print(f"Generated synthetic LC: {len(t_synth)} points over {t_synth[-1]:.1f} days")

# Draw the raw samples as small black dots on a single wide panel
raw_fig, raw_ax = plt.subplots(figsize=(12, 4))
raw_ax.plot(t_synth, flux_synth, 'k.', markersize=0.5, alpha=0.5)
raw_ax.set_xlabel('Time (days)')
raw_ax.set_ylabel('Relative Flux')
raw_ax.set_title('Raw Light Curve')
plt.show()

In [None]:
# Option 2: Load real data from artifacts (if available)
# Any failure while reading the archive falls back to the synthetic light curve.
try:
    # Check for preprocessed data
    data_path = Path('../artifacts/preprocessed_lcs.npz')
    if data_path.exists():
        # np.load on an .npz returns a lazy archive that keeps the file open;
        # the context manager closes it once the arrays have been read.
        with np.load(data_path) as data:
            t_real = data['time'][0]  # First LC
            flux_real = data['flux'][0]
        print(f"Loaded real LC: {len(t_real)} points")

        # Use real data
        t, flux = t_real, flux_real
    else:
        print("No real data found, using synthetic")
        t, flux = t_synth, flux_synth
except Exception as e:
    # Deliberate best-effort: report the reason and continue with synthetic data
    print(f"Using synthetic data: {e}")
    t, flux = t_synth, flux_synth

## 2. GP Denoising

Apply Gaussian Process regression to remove stellar variability while preserving transit signals.

In [None]:
# Run the GP denoiser on the selected light curve
print("Applying GP denoising...")
residuals, trend = denoise(t, flux, backend='auto')

# Shift the residuals back onto a unit baseline
flux_denoised = residuals + 1

# Three stacked panels sharing the time axis: original, denoised, residuals
fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)
ax_original, ax_denoised, ax_residual = axes

point_style = dict(markersize=0.5, alpha=0.5)
reference_style = dict(color='gray', linestyle='--', alpha=0.5)

# Top panel: input flux with the fitted trend overlaid
ax_original.plot(t, flux, 'k.', **point_style)
ax_original.plot(t, 1 + trend, 'r-', linewidth=2, label='GP trend')
ax_original.set_ylabel('Original Flux')
ax_original.legend()

# Middle panel: denoised flux against the unit baseline
ax_denoised.plot(t, flux_denoised, 'b.', **point_style)
ax_denoised.axhline(1.0, **reference_style)
ax_denoised.set_ylabel('Denoised Flux')

# Bottom panel: residuals against zero
ax_residual.plot(t, residuals, 'g.', **point_style)
ax_residual.axhline(0, **reference_style)
ax_residual.set_xlabel('Time (days)')
ax_residual.set_ylabel('Residuals')

plt.suptitle('GP Denoising Results')
plt.tight_layout()
plt.show()

# Summary statistics of the decomposition
print(f"Residuals std: {np.std(residuals):.6f}")
print(f"Trend range: {np.ptp(trend):.6f}")

## 3. TLS Period Search

Apply Transit Least Squares to find periodic transit signals.

In [None]:
# Run TLS search on the denoised light curve
print("Running TLS search...")
tls_result = run_tls(t, flux_denoised)

# Extract key parameters
if tls_result and hasattr(tls_result, 'period'):
    period_tls = tls_result.period
    t0_tls = tls_result.T0
    duration_tls = tls_result.duration
    # NOTE(review): assumes run_tls returns the transitleastsquares results
    # object, whose `depth` is the in-transit flux level (about 0.999 for a
    # 1000 ppm transit) - confirm against app.search.tls_runner.
    # depth_tls is therefore the fractional transit depth.
    depth_tls = 1 - tls_result.depth
    sde_tls = tls_result.SDE

    print(f"\nTLS Results:")
    print(f"  Period: {period_tls:.4f} days")
    print(f"  T0: {t0_tls:.4f} days")
    print(f"  Duration: {duration_tls:.4f} days ({duration_tls*24:.2f} hours)")
    # depth_tls is already fractional; inverting it again would print ~1e6 ppm
    print(f"  Depth: {depth_tls*1e6:.1f} ppm")
    print(f"  SDE: {sde_tls:.2f}")

    # Two panels: periodogram on top, phase-folded light curve below
    fig, axes = plt.subplots(2, 1, figsize=(12, 8))

    # Periodogram (only drawn when the result exposes the spectrum)
    if hasattr(tls_result, 'periods') and hasattr(tls_result, 'power'):
        axes[0].plot(tls_result.periods, tls_result.power, 'b-')
        axes[0].axvline(period_tls, color='r', linestyle='--', label=f'Best period: {period_tls:.4f} d')
        axes[0].set_xlabel('Period (days)')
        axes[0].set_ylabel('SDE')
        axes[0].set_title('TLS Periodogram')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

    # Phase-folded light curve: fold on the best period, then wrap phases
    # above 0.5 to negative values so the transit sits at phase 0
    phase = (t - t0_tls) % period_tls / period_tls
    phase[phase > 0.5] -= 1.0

    axes[1].plot(phase, flux_denoised, 'k.', markersize=0.5, alpha=0.3)

    # Bin the phase-folded data: median flux per bin, empty bins skipped
    n_bins = 100
    phase_bins = np.linspace(-0.5, 0.5, n_bins + 1)
    binned_flux = []
    binned_phase = []
    for i in range(n_bins):
        mask = (phase >= phase_bins[i]) & (phase < phase_bins[i+1])
        if np.any(mask):
            binned_flux.append(np.median(flux_denoised[mask]))
            binned_phase.append((phase_bins[i] + phase_bins[i+1]) / 2)

    axes[1].plot(binned_phase, binned_flux, 'ro-', markersize=3, linewidth=1)
    axes[1].axhline(1.0, color='gray', linestyle='--', alpha=0.5)
    axes[1].set_xlabel('Phase')
    axes[1].set_ylabel('Relative Flux')
    axes[1].set_title(f'Phase-folded LC (P={period_tls:.4f} d)')
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print("TLS search failed or no signal found")

## 4. Optional: BLS Cross-validation

Use Box Least Squares from Astropy for comparison and cross-validation.

In [None]:
# Optional BLS search for cross-validation.
# Best-effort: a missing astropy install or any runtime failure is reported
# and the notebook carries on.
try:
    from astropy import units as u
    from astropy.timeseries import BoxLeastSquares

    print("\nRunning BLS for cross-validation...")

    # BLS model over the denoised flux, with times tagged as days
    bls = BoxLeastSquares(t * u.day, flux_denoised)

    # 1000 log-spaced trial periods from 0.5 to 20 days
    log_period_lo, log_period_hi = np.log(0.5), np.log(20)
    period_grid = np.exp(np.linspace(log_period_lo, log_period_hi, 1000))
    bls_result = bls.power(period_grid * u.day, 0.05, oversample=20)

    # Take the parameters at the peak of the BLS power spectrum
    best_idx = np.argmax(bls_result.power)
    period_bls = bls_result.period[best_idx].value
    t0_bls = bls_result.transit_time[best_idx].value
    duration_bls = bls_result.duration[best_idx].value
    depth_bls = bls_result.depth[best_idx]
    snr_bls = bls_result.depth_snr[best_idx]

    print(f"\nBLS Results:")
    print(f"  Period: {period_bls:.4f} days")
    print(f"  T0: {t0_bls:.4f} days")
    print(f"  Duration: {duration_bls:.4f} days ({duration_bls*24:.2f} hours)")
    print(f"  Depth: {depth_bls*1e6:.1f} ppm")
    print(f"  SNR: {snr_bls:.2f}")

    # TLS values exist only if the TLS cell ran and found a signal
    tls_available = 'period_tls' in locals()

    # Compare TLS vs BLS
    if tls_available:
        print(f"\n=== TLS vs BLS Comparison ===")
        print(f"Period difference: {abs(period_tls - period_bls):.4f} days")
        print(f"T0 difference: {abs(t0_tls - t0_bls):.4f} days")
        print(f"Duration difference: {abs(duration_tls - duration_bls)*24:.2f} hours")

    # BLS periodogram with the best period(s) marked
    bls_fig, bls_ax = plt.subplots(figsize=(12, 4))
    bls_ax.plot(bls_result.period, bls_result.power, 'g-', label='BLS')
    bls_ax.axvline(period_bls, color='g', linestyle='--', alpha=0.7, label=f'BLS: {period_bls:.4f} d')
    if tls_available:
        bls_ax.axvline(period_tls, color='r', linestyle='--', alpha=0.7, label=f'TLS: {period_tls:.4f} d')
    bls_ax.set_xlabel('Period (days)')
    bls_ax.set_ylabel('Power')
    bls_ax.set_title('BLS Periodogram')
    bls_ax.legend()
    bls_ax.grid(True, alpha=0.3)
    plt.show()

except ImportError:
    print("Astropy not installed, skipping BLS cross-validation")
except Exception as e:
    print(f"BLS search failed: {e}")

## 5. Save Results

Save the transit parameters for use in CNN training and inference.

In [None]:
# Prepare results dictionary; 'tls' / 'bls' stay empty when that search
# did not produce parameters
results = {
    'tls': {},
    'bls': {},
    'metadata': {
        'n_points': len(t),
        'time_span': float(t[-1] - t[0]),
        'cadence': float(np.median(np.diff(t)) * 24 * 60),  # in minutes
    }
}

# Add TLS results
if 'period_tls' in locals():
    results['tls'] = {
        'period': float(period_tls),
        't0': float(t0_tls),
        'duration': float(duration_tls),
        # depth_tls is computed as 1 - tls_result.depth, i.e. it is already the
        # fractional depth; inverting it again would store ~1e6 ppm
        'depth_ppm': float(depth_tls*1e6),
        'sde': float(sde_tls)
    }

# Add BLS results
if 'period_bls' in locals():
    results['bls'] = {
        'period': float(period_bls),
        't0': float(t0_bls),
        'duration': float(duration_bls),
        'depth_ppm': float(depth_bls*1e6),
        'snr': float(snr_bls)
    }

# Save to JSON
output_dir = Path('../artifacts')
output_dir.mkdir(parents=True, exist_ok=True)

output_file = output_dir / 'transit_search_results.json'
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)

print(f"\nResults saved to {output_file}")
print(json.dumps(results, indent=2))

# Also save denoised light curve
lc_file = output_dir / 'denoised_lc.npz'
lc_arrays = {
    'time': t,
    'flux_raw': flux,
    'flux_denoised': flux_denoised,
}
# Store the trend only when it exists: a None entry would be written as a
# pickled object array, which np.load rejects unless allow_pickle=True
if 'trend' in locals():
    lc_arrays['gp_trend'] = trend
np.savez_compressed(lc_file, **lc_arrays)
print(f"Light curve saved to {lc_file}")

## 6. Summary

This notebook demonstrates the complete TLS search pipeline:
1. Light curve loading/generation
2. GP denoising to remove stellar variability
3. TLS period search for transit detection
4. Optional BLS cross-validation
5. Results saved for downstream CNN training

The transit parameters (period, t0, duration) will be used in the next step (03b_cnn_train.ipynb) to create global and local views for CNN training.