# 02b - 从本地光曲线提取特征

**前置条件**: 必须先运行 `02a_download_lightcurves.ipynb` 下载光曲线

**目标**: 从已下载的光曲线文件中提取基础统计特征和 BLS（Box Least Squares）特征

**优势**:
- ✅ 不需要网络连接
- ✅ 可以快速迭代不同特征
- ✅ 支持并行处理
- ✅ 内存高效（批处理）

## Cell 1: 安装依赖

In [None]:
import sys
if 'google.colab' in sys.modules:
    print("📦 Installing dependencies...")
    !pip install -q numpy==1.26.4 'scipy<1.13' pandas scikit-learn joblib tqdm
    !pip install -q astropy
    print("✅ Installation complete")
else:
    print("✅ Running locally")

## Cell 2: 导入和配置

In [None]:
import os
import sys
import json
import warnings
from pathlib import Path
from typing import Dict, Optional
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np
import pandas as pd
from tqdm import tqdm
import joblib
from astropy.timeseries import BoxLeastSquares

# Global notebook settings: seed NumPy's legacy global RNG with a fixed value
# and silence every Python warning for the rest of the session.
np.random.seed(42)
warnings.filterwarnings('ignore')

# Report the library versions in use.
numpy_version, pandas_version = np.__version__, pd.__version__
print("✅ Imports successful")
print(f"   NumPy: {numpy_version}")
print(f"   Pandas: {pandas_version}")

## Cell 3: 配置路径

In [None]:
# Resolve every directory the notebook reads from or writes to.
# On Colab the light curves, checkpoints and models live on Google Drive;
# locally they live inside the repository.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("🌍 Running in Google Colab")
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)

    REPO_DIR = Path('/content/exoplanet-starter')
    if not REPO_DIR.exists():
        # os.chdir would raise the same exception type, but without saying
        # what the user has to do about it.
        raise FileNotFoundError(
            f"❌ Repository directory not found: {REPO_DIR}\n"
            f"   Please clone the repository into /content first!"
        )
    os.chdir(str(REPO_DIR))
    BASE_DIR = REPO_DIR
    LIGHTCURVE_DIR = Path('/content/drive/MyDrive/spaceapps-lightcurves')
    CHECKPOINT_DIR = Path('/content/drive/MyDrive/spaceapps-checkpoints')
    MODEL_DIR = Path('/content/drive/MyDrive/spaceapps-models')
else:
    print("💻 Running locally")
    # Step up one level only when the working directory itself is the
    # notebooks/ folder; a substring test on the whole path would also match
    # unrelated ancestor directories whose name happens to contain "notebooks".
    _cwd = Path.cwd()
    BASE_DIR = _cwd.parent if _cwd.name == 'notebooks' else _cwd
    LIGHTCURVE_DIR = BASE_DIR / 'data' / 'lightcurves'
    CHECKPOINT_DIR = BASE_DIR / 'checkpoints'
    MODEL_DIR = BASE_DIR / 'models'

DATA_DIR = BASE_DIR / 'data'

# Create both output directories up front so that neither the checkpoint
# writes nor the final model/feature writes fail on a fresh setup.
MODEL_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

print(f"\n✅ Paths configured:")
print(f"   Lightcurves: {LIGHTCURVE_DIR}")
print(f"   Checkpoints: {CHECKPOINT_DIR}")
print(f"   Models: {MODEL_DIR}")

# Verify that the light-curve directory exists
if not LIGHTCURVE_DIR.exists():
    raise FileNotFoundError(
        f"❌ Lightcurve directory not found: {LIGHTCURVE_DIR}\n"
        f"   Please run 02a_download_lightcurves.ipynb first!"
    )

# Sorted so the processing order is reproducible between runs
# (glob order depends on the file system).
lc_files = sorted(LIGHTCURVE_DIR.glob('*.pkl'))
print(f"\n📦 Found {len(lc_files):,} lightcurve files")

## Cell 4: 特征提取配置

In [None]:
# Feature-extraction settings shared by the cells below.
CONFIG = dict(
    max_workers=4,       # number of parallel worker processes (2-8 suggested)
    batch_size=100,      # batch size
    save_interval=50,    # write a checkpoint every N extracted samples
    bls_periods=2000,    # number of trial periods in the BLS search
    bls_durations=10,    # number of trial durations in the BLS search
    period_min=0.5,      # shortest trial period (days)
    period_max=15.0,     # longest trial period (days)
)

print("⚙️ Feature Extraction Configuration:")
for option, setting in CONFIG.items():
    print(f"   {option}: {setting}")

# Load the sample metadata produced by the earlier notebooks.
dataset_path = DATA_DIR / 'supervised_dataset.csv'
samples_df = pd.read_csv(dataset_path)

# Synthesize zero-padded sequential IDs when the CSV carries no sample_id column.
has_sample_ids = 'sample_id' in samples_df.columns
if not has_sample_ids:
    samples_df['sample_id'] = [
        f"SAMPLE_{row_number:06d}" for row_number in range(samples_df.shape[0])
    ]

print(f"\n✅ Dataset loaded: {len(samples_df):,} samples")

## Cell 5: 特征提取函数

In [None]:
def _bls_duration_grid(period_min: float, n_durations: int) -> np.ndarray:
    """Return trial transit durations that stay shorter than ``period_min``.

    astropy's BoxLeastSquares rejects a search whose longest trial duration is
    not shorter than its shortest trial period. The grid keeps the nominal
    0.05-0.5 range whenever that range already fits below ``period_min`` and
    shrinks it otherwise.
    """
    longest = min(0.5, 0.8 * period_min)
    shortest = min(0.05, 0.5 * longest)
    return np.linspace(shortest, longest, n_durations)


def extract_features_from_lightcurve(
    time: np.ndarray,
    flux: np.ndarray,
    period: float = 1.0,
    duration: float = 0.1,
    depth: float = 0.01
) -> Dict[str, float]:
    """Extract summary statistics and Box Least Squares (BLS) features.

    Args:
        time: Observation times as a NumPy array. The trial periods come from
            ``CONFIG['period_min']`` / ``CONFIG['period_max']``, so the unit
            is presumably days -- confirm against the download notebook.
        flux: Flux values, one per entry of ``time``.
        period: Reference period; stored as ``input_period`` and used as the
            ``bls_period`` fallback when the BLS search fails.
        duration: Reference duration; stored as ``input_duration`` and used as
            the ``bls_duration`` fallback.
        depth: Reference depth; stored as ``input_depth`` and used as the
            ``bls_depth`` fallback.

    Returns:
        A dict with the keys ``flux_mean``, ``flux_std``, ``flux_median``,
        ``flux_mad``, ``n_points``, ``time_span``, ``input_period``,
        ``input_duration``, ``input_depth``, ``bls_power``, ``bls_period``,
        ``bls_duration``, ``bls_depth`` and ``bls_snr``.

    Raises:
        ValueError: If ``time`` is empty.
    """
    import logging  # local import keeps this cell independent of the import cell

    if len(time) == 0:
        # Fail with a clear message instead of an IndexError from time[-1].
        raise ValueError("Cannot extract features from an empty light curve")

    features = {}

    # Basic statistics (NaN-aware)
    flux_median = np.nanmedian(flux)
    features['flux_mean'] = float(np.nanmean(flux))
    features['flux_std'] = float(np.nanstd(flux))
    features['flux_median'] = float(flux_median)
    features['flux_mad'] = float(np.nanmedian(np.abs(flux - flux_median)))
    features['n_points'] = len(time)
    features['time_span'] = float(time[-1] - time[0])

    # Input parameters, echoed back unchanged
    features['input_period'] = float(period)
    features['input_duration'] = float(duration)
    features['input_depth'] = float(depth)

    # BLS analysis
    try:
        # Drop non-finite samples before the search so they cannot poison
        # the periodogram.
        finite = np.isfinite(time) & np.isfinite(flux)
        bls = BoxLeastSquares(time[finite], flux[finite])

        period_min = CONFIG['period_min']
        periods = np.linspace(period_min, CONFIG['period_max'], CONFIG['bls_periods'])
        durations = _bls_duration_grid(period_min, CONFIG['bls_durations'])
        bls_result = bls.power(periods, durations)

        best_idx = int(np.argmax(bls_result.power))
        features['bls_power'] = float(bls_result.power[best_idx])
        features['bls_period'] = float(bls_result.period[best_idx])
        features['bls_duration'] = float(bls_result.duration[best_idx])
        features['bls_depth'] = float(bls_result.depth[best_idx])
        features['bls_snr'] = float(bls_result.depth_snr[best_idx])

    except Exception as exc:
        # Best effort: keep the sample, but leave a trace of why BLS failed.
        # NOTE(review): the fallback copies the reference period/duration/depth
        # into the bls_* features -- check that this cannot leak label
        # information into training.
        logging.getLogger(__name__).warning(
            "BLS search failed (%s: %s); using fallback values",
            type(exc).__name__, exc,
        )
        features.update({
            'bls_power': 0.0,
            'bls_period': float(period),
            'bls_duration': float(duration),
            'bls_depth': float(depth),
            'bls_snr': 0.0
        })

    return features


def process_single_file(file_path: Path, metadata: pd.Series) -> Optional[Dict]:
    """Load one pickled light curve and turn it into a feature dictionary.

    Args:
        file_path: Path of the light-curve pickle file.
        metadata: Sample metadata row (from supervised_dataset.csv). The keys
            ``period``, ``duration``, ``depth`` and ``label`` are read, with
            defaults of 1.0, 0.1, 0.01 and 0 when a key is absent.

    Returns:
        The feature dictionary from ``extract_features_from_lightcurve`` plus
        ``sample_id``, ``tic_id``, ``label`` and ``n_sectors``, or ``None``
        when anything goes wrong while loading or processing the file.
    """
    import logging  # local import keeps this cell independent of the import cell

    try:
        # Load the light-curve data.
        # NOTE: joblib.load unpickles arbitrary objects -- only use it on files
        # written by the trusted download notebook.
        data = joblib.load(file_path)
        lc_collection = data['lc_collection']

        # Only the light curve of the first sector is used
        lc = lc_collection[0].remove_nans().normalize()

        # Extract features
        features = extract_features_from_lightcurve(
            lc.time.value,
            lc.flux.value,
            metadata.get('period', 1.0),
            metadata.get('duration', 0.1),
            metadata.get('depth', 0.01)
        )

        # Attach identifying metadata
        features['sample_id'] = data['sample_id']
        features['tic_id'] = data['tic_id']
        features['label'] = metadata.get('label', 0)
        features['n_sectors'] = data['n_sectors']

        return features

    except Exception as exc:
        # Still best effort (the caller skips None results), but record the
        # reason so systematic failures do not go unnoticed.
        logging.getLogger(__name__).warning(
            "Skipping %s (%s: %s)", file_path, type(exc).__name__, exc
        )
        return None


# Confirm that the feature-extraction cell ran to the end.
print("✅ Feature extraction functions defined")

## Cell 6: 批量特征提取（主要执行）

In [None]:
def _sample_id_from_file(lc_file: Path) -> str:
    """Return the sample ID encoded in a light-curve file name (the part before '_TIC')."""
    return lc_file.stem.split('_TIC')[0]


def _save_checkpoint(df: pd.DataFrame, path: Path) -> None:
    """Write ``df`` to ``path`` through a temporary file, so an interrupted write cannot corrupt it."""
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = path.with_name(path.name + '.tmp')
    df.to_parquet(tmp_path, index=False)
    tmp_path.replace(path)


# Load the checkpoint, if one exists
checkpoint_path = CHECKPOINT_DIR / 'features_checkpoint.parquet'
if checkpoint_path.exists():
    features_df = pd.read_parquet(checkpoint_path)
    print(f"📂 Loaded checkpoint: {len(features_df)} features")
else:
    features_df = pd.DataFrame()
    print("🆕 Starting fresh")

# Determine which files still need processing. The sample ID is parsed from
# the file name exactly as it is when the tasks are built below; splitting on
# a bare '_' would yield only "SAMPLE" for IDs such as "SAMPLE_000001" and
# never match an already processed ID.
if len(features_df) > 0:
    processed_ids = set(features_df['sample_id'])
    lc_files_todo = [f for f in lc_files if _sample_id_from_file(f) not in processed_ids]
else:
    lc_files_todo = lc_files

print(f"\n📊 Extraction Progress:")
print(f"   Total files: {len(lc_files):,}")
print(f"   Already processed: {len(lc_files) - len(lc_files_todo):,}")
print(f"   Remaining: {len(lc_files_todo):,}")

if len(lc_files_todo) == 0:
    print("\n✅ All features already extracted!")
else:
    print(f"\n🚀 Starting feature extraction")
    print(f"   Workers: {CONFIG['max_workers']}")
    print(f"   Estimated time: {len(lc_files_todo) * 2 / 60 / CONFIG['max_workers']:.1f} minutes")

    # Index the metadata once instead of scanning the whole frame per file.
    # Keeping the first row per ID mirrors a "first match" lookup.
    metadata_by_id = (
        samples_df.drop_duplicates('sample_id').set_index('sample_id', drop=False)
    )

    # Build the task list, skipping files that have no metadata row
    tasks = []
    files_without_metadata = []
    for lc_file in lc_files_todo:
        sample_id = _sample_id_from_file(lc_file)
        if sample_id not in metadata_by_id.index:
            files_without_metadata.append(lc_file.name)
            continue
        tasks.append((lc_file, metadata_by_id.loc[sample_id]))

    if files_without_metadata:
        print(f"   ⚠️ Skipping {len(files_without_metadata):,} files without metadata "
              f"(e.g. {files_without_metadata[0]})")

    new_features = []
    n_failed = 0

    # Process the files in parallel
    with ProcessPoolExecutor(max_workers=CONFIG['max_workers']) as executor:
        futures = [
            executor.submit(process_single_file, lc_file, metadata)
            for lc_file, metadata in tasks
        ]

        with tqdm(total=len(futures), desc="Extracting") as pbar:
            for future in as_completed(futures):
                try:
                    result = future.result()
                except Exception:
                    # A crashed worker counts as a failed file, so the results
                    # gathered so far still reach the final save below.
                    result = None
                pbar.update(1)

                if result is None:
                    n_failed += 1
                    continue
                new_features.append(result)

                # Periodic save: checked only right after a successful append,
                # so a run of failures cannot trigger the same save repeatedly.
                if len(new_features) % CONFIG['save_interval'] == 0:
                    temp_df = pd.concat([features_df, pd.DataFrame(new_features)], ignore_index=True)
                    _save_checkpoint(temp_df, checkpoint_path)
                    pbar.set_postfix({'saved': len(new_features)})

    if n_failed > 0:
        print(f"\n⚠️ {n_failed:,} files could not be processed")

    # Final save
    if len(new_features) > 0:
        features_df = pd.concat([features_df, pd.DataFrame(new_features)], ignore_index=True)
        _save_checkpoint(features_df, checkpoint_path)
        print(f"\n💾 Saved {len(features_df):,} features")

# Feature columns = everything except the identifier/label columns
feature_cols = [col for col in features_df.columns
                if col not in ['sample_id', 'tic_id', 'label', 'n_sectors']]

print(f"\n✅ Feature extraction complete")
print(f"   Total features: {len(features_df):,}")
print(f"   Feature columns: {len(feature_cols)}")
print(f"   Features: {feature_cols}")

## Cell 7: 数据质量检查

In [None]:
print("🔍 Data Quality Check:\n")

n_rows = len(features_df)

# Missing values per feature column
print("📊 Missing values:")
missing = features_df[feature_cols].isnull().sum()
incomplete_cols = missing[missing > 0]
if incomplete_cols.empty:
    print("   ✅ No missing values!")
else:
    for col, count in incomplete_cols.items():
        pct = count / n_rows * 100
        print(f"   {col}: {count} ({pct:.1f}%)")

# Label distribution
print("\n📊 Label distribution:")
label_counts = features_df['label'].value_counts()
n_positive = label_counts.get(1, 0)
n_negative = label_counts.get(0, 0)
print(f"   Positive (1): {n_positive:,} ({n_positive/n_rows*100:.1f}%)")
print(f"   Negative (0): {n_negative:,} ({n_negative/n_rows*100:.1f}%)")

# Summary statistics, limited to the first five feature columns
print("\n📊 Feature statistics (sample):")
feature_summary = features_df[feature_cols].describe()
print(feature_summary.iloc[:, :5])

# Count infinities and entries whose magnitude exceeds 1e10
print("\n🔍 Checking for infinities and extreme values...")
X = features_df[feature_cols].values
n_inf = np.count_nonzero(np.isinf(X))
n_extreme = np.count_nonzero(np.abs(X) > 1e10)

if n_inf == 0 and n_extreme == 0:
    print("   ✅ No infinities or extreme values")
else:
    if n_inf > 0:
        print(f"   ⚠️ Found {n_inf} infinity values")
    if n_extreme > 0:
        print(f"   ⚠️ Found {n_extreme} extreme values (>1e10)")

## Cell 8: 保存最终特征

In [None]:
from datetime import datetime

# One timestamp ties together the three files written by this cell.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

# Feature table
features_path = MODEL_DIR / f'features_{timestamp}.parquet'
features_df.to_parquet(features_path, index=False)
print(f"✅ Features saved: {features_path}")

# Names of the feature columns
feature_cols_path = MODEL_DIR / f'feature_columns_{timestamp}.json'
feature_cols_path.write_text(json.dumps(feature_cols, indent=2))
print(f"✅ Feature columns saved: {feature_cols_path}")

# Extraction report: configuration plus the quality-check results
label_distribution = {
    'positive': int(label_counts.get(1, 0)),
    'negative': int(label_counts.get(0, 0))
}
data_quality = {
    'missing_values': int(missing.sum()),
    'infinities': int(n_inf),
    'extreme_values': int(n_extreme)
}
report = {
    'timestamp': timestamp,
    'total_samples': len(features_df),
    'n_features': len(feature_cols),
    'feature_names': feature_cols,
    'label_distribution': label_distribution,
    'config': CONFIG,
    'data_quality': data_quality
}

report_path = MODEL_DIR / f'extraction_report_{timestamp}.json'
report_path.write_text(json.dumps(report, indent=2))
print(f"✅ Report saved: {report_path}")

print(f"\n🎉 All files saved to: {MODEL_DIR}")

## ✅ 特征提取完成！

### 下一步:
1. 运行 `03_injection_train.ipynb` 或 `03_injection_train_PRODUCTION.ipynb` 训练模型
2. 可以修改 Cell 4 的配置尝试不同的特征提取策略
3. 重新运行时不需要重新下载光曲线：先依次运行 Cell 2–5（导入、路径、配置、函数定义），再运行 Cell 6 即可

### 输出文件（保存在 `MODEL_DIR`：本地为 `models/`，Colab 上为 Google Drive 的 `spaceapps-models/`）:
- **特征数据**: `models/features_{timestamp}.parquet`
- **特征列定义**: `models/feature_columns_{timestamp}.json`
- **提取报告**: `models/extraction_report_{timestamp}.json`