# 02a - 批量下载光曲线数据

**目标**: 一次性下载数据集（`data/supervised_dataset.csv`，含正、负样本）中所有样本的 TESS 光曲线，保存为本地文件

**特点**:
- ✅ 批量并发下载（可配置线程数）
- ✅ 断点续传（按样本记录进度，中断后重新运行即可从未完成的样本继续）
- ✅ 自动重试失败样本
- ✅ 进度追踪和统计
- ✅ 保存为本地 pickle 文件（joblib 序列化，包含完整的 LightCurveCollection，每个样本一个 `.pkl`）

**预计时间**: 4-8 小时（取决于网络和样本数）

## Cell 1: 安装依赖

In [None]:
# Colab environment detection and dependency installation.
# Colab is detected by the 'google.colab' module already being present in
# sys.modules; outside Colab nothing is installed.
import sys
if 'google.colab' in sys.modules:
    print("📦 Installing dependencies...")
    # IPython shell escape (not plain Python). The numpy/scipy pins are kept as-is;
    # NOTE(review): the reason for the pins is not stated here — presumably
    # compatibility with lightkurve on Colab; confirm before changing them.
    !pip install -q numpy==1.26.4 'scipy<1.13' lightkurve pandas tqdm joblib
    print("✅ Installation complete")
    print("⚠️ If errors, restart runtime and skip to Cell 2")
else:
    print("✅ Running locally")

## Cell 2: 导入库和配置

In [None]:
import os
import sys
import json
import time
import warnings
from pathlib import Path
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pandas as pd
from tqdm import tqdm
import lightkurve as lk
import joblib

# Suppress all warnings for the rest of the session to keep cell output short.
# NOTE(review): this also hides deprecation/runtime warnings from every library.
warnings.filterwarnings('ignore')

print("✅ Imports successful")
# Print the versions of the core packages this notebook depends on.
for _pkg_label, _pkg_module in (("Lightkurve", lk), ("NumPy", np), ("Pandas", pd)):
    print(f"   {_pkg_label}: {_pkg_module.__version__}")

## Cell 3: 配置路径

In [None]:
# 环境检测
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("🌍 Running in Google Colab")
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    
    REPO_DIR = Path('/content/exoplanet-starter')
    if not REPO_DIR.exists():
        print("📥 Cloning repository...")
        !git clone https://github.com/exoplanet-spaceapps/exoplanet-starter.git /content/exoplanet-starter
    
    os.chdir(str(REPO_DIR))
    BASE_DIR = REPO_DIR
    LIGHTCURVE_DIR = Path('/content/drive/MyDrive/spaceapps-lightcurves')
    CHECKPOINT_DIR = Path('/content/drive/MyDrive/spaceapps-checkpoints')
else:
    print("💻 Running locally")
    BASE_DIR = Path.cwd().parent if 'notebooks' in str(Path.cwd()) else Path.cwd()
    LIGHTCURVE_DIR = BASE_DIR / 'data' / 'lightcurves'
    CHECKPOINT_DIR = BASE_DIR / 'checkpoints'

DATA_DIR = BASE_DIR / 'data'
LIGHTCURVE_DIR.mkdir(parents=True, exist_ok=True)
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

print(f"\n✅ Paths configured:")
print(f"   Data: {DATA_DIR}")
print(f"   Lightcurves: {LIGHTCURVE_DIR}")
print(f"   Checkpoints: {CHECKPOINT_DIR}")

## Cell 4: 下载配置

In [None]:
# Download configuration.
CONFIG = {
    'max_workers': 4,        # concurrent download threads (2-8 suggested)
    'max_retries': 3,        # total download attempts per sample
    'timeout': 60,           # per-download timeout (s); NOTE(review): not read by the download code in this notebook
    'batch_size': 100,       # NOTE(review): not read by the download code in this notebook
    'save_interval': 50,     # write a progress checkpoint every N completed samples
    'test_mode': False,      # True = only download the first 100 samples as a test
}

print("⚙️ Download Configuration:")
for key, val in CONFIG.items():
    print(f"   {key}: {val}")

# Load the dataset.
dataset_path = DATA_DIR / 'supervised_dataset.csv'
if not dataset_path.exists():
    raise FileNotFoundError(f"❌ Dataset not found: {dataset_path}")

samples_df = pd.read_csv(dataset_path)

# Test mode: keep only the first 100 rows. Copy so the column assignments below
# operate on an independent frame, not a slice of the full one.
if CONFIG['test_mode']:
    samples_df = samples_df.head(100).copy()
    print(f"\n⚠️ TEST MODE: Only processing {len(samples_df)} samples")

# Assign positional sample IDs when the dataset has none. Cached file names and
# the checkpoint are keyed on these IDs, so they stay stable only while the CSV
# row order stays the same.
if 'sample_id' not in samples_df.columns:
    samples_df['sample_id'] = [f"SAMPLE_{i:06d}" for i in range(len(samples_df))]

# Normalise the TIC ID column name. Fail fast here instead of with a KeyError
# inside every download worker.
if 'tic_id' not in samples_df.columns:
    for alias in ('tid', 'target_id'):
        if alias in samples_df.columns:
            samples_df['tic_id'] = samples_df[alias]
            break
    else:
        raise KeyError(
            "❌ No TIC ID column found (expected 'tic_id', 'tid' or 'target_id'); "
            f"available columns: {list(samples_df.columns)}"
        )

n_missing_tic = int(samples_df['tic_id'].isna().sum())

print(f"\n✅ Dataset loaded: {len(samples_df):,} samples")
print(f"   Positive: {samples_df['label'].sum():,}")
print(f"   Negative: {(~samples_df['label'].astype(bool)).sum():,}")
if n_missing_tic:
    print(f"   ⚠️ Missing TIC ID: {n_missing_tic:,} rows (cannot be downloaded)")

## Cell 5: 下载函数

In [None]:
def download_single_lightcurve(row: pd.Series, retries: int = 3,
                               output_dir: "Path | None" = None) -> dict:
    """
    Download every SPOC light curve for one TIC ID and save it as a joblib pickle.

    The output file is ``{sample_id}_TIC{tic_id}.pkl``. If it already exists,
    nothing is downloaded and the sample is reported as ``'cached'``.

    Args:
        row: Sample row providing at least ``'sample_id'`` and ``'tic_id'``.
        retries: Total number of download attempts. Attempts are separated by
            exponential backoff of 1 s, 2 s, 4 s, ...
        output_dir: Directory for the pickle files. Defaults to ``LIGHTCURVE_DIR``.

    Returns:
        dict: {'sample_id', 'tic_id', 'status', 'file_path', 'n_sectors', 'error'}.
        ``status`` is ``'success'``, ``'cached'`` or ``'failed'``. Failures are
        reported through ``status``/``error`` rather than raised, so one bad
        sample cannot abort a batch of concurrent downloads.
    """
    if output_dir is None:
        output_dir = LIGHTCURVE_DIR

    sample_id = row['sample_id']
    result = {
        'sample_id': sample_id,
        'tic_id': None,
        'status': 'failed',
        'file_path': None,
        'n_sectors': 0,
        'error': None,
    }

    # TIC IDs may arrive as floats or strings (e.g. "12345.0") or be missing (NaN).
    # Parse under a guard so a bad ID is recorded instead of raising out of the
    # worker thread and aborting the whole batch.
    try:
        tic_id = int(float(row['tic_id']))
    except (TypeError, ValueError, OverflowError):
        result['error'] = 'invalid_tic_id'
        return result
    result['tic_id'] = tic_id

    # Skip samples whose file is already on disk (resume support).
    file_path = Path(output_dir) / f"{sample_id}_TIC{tic_id}.pkl"
    if file_path.exists():
        result['status'] = 'cached'
        result['file_path'] = str(file_path)
        return result

    for attempt in range(retries):
        try:
            search_result = lk.search_lightcurve(f"TIC {tic_id}", author='SPOC')
            if search_result is None or len(search_result) == 0:
                # Definitive answer from the archive; retrying would not help.
                result['error'] = 'no_data_found'
                return result

            # Download all sectors found by the search.
            lc_collection = search_result.download_all()
            if lc_collection is None or len(lc_collection) == 0:
                result['error'] = 'download_failed'
                return result

            # Keep the full LightCurveCollection plus some metadata.
            save_data = {
                'sample_id': sample_id,
                'tic_id': tic_id,
                'lc_collection': lc_collection,
                'n_sectors': len(lc_collection),
                'download_time': datetime.now().isoformat(),
                'sectors': [lc.meta.get('SECTOR', 'unknown') for lc in lc_collection],
            }

            # Write under a temporary name, then rename into place. An
            # interrupted run therefore never leaves a truncated .pkl that the
            # cache check above would treat as a finished download.
            tmp_path = file_path.with_name(file_path.name + '.tmp')
            try:
                joblib.dump(save_data, tmp_path)
                os.replace(tmp_path, file_path)
            finally:
                tmp_path.unlink(missing_ok=True)  # no-op after a successful rename

            result['status'] = 'success'
            result['file_path'] = str(file_path)
            result['n_sectors'] = len(lc_collection)
            return result

        except Exception as e:
            # Include the exception type: some network errors have an empty str().
            result['error'] = f"{type(e).__name__}: {e}"
            if attempt < retries - 1:
                time.sleep(2 ** attempt)  # exponential backoff

    return result


def load_checkpoint() -> pd.DataFrame:
    """
    Read the saved download-progress table.

    Returns:
        pd.DataFrame: contents of ``CHECKPOINT_DIR/download_progress.parquet``,
        or an empty DataFrame when no checkpoint file exists yet.
    """
    ckpt_file = CHECKPOINT_DIR / 'download_progress.parquet'
    if not ckpt_file.exists():
        return pd.DataFrame()

    loaded_df = pd.read_parquet(ckpt_file)
    print(f"📂 Loaded checkpoint: {len(loaded_df)} downloads")
    return loaded_df


def save_checkpoint(progress_df: pd.DataFrame):
    """
    Persist the download-progress table to ``CHECKPOINT_DIR/download_progress.parquet``.

    The table is first written to a temporary file and then renamed over the
    real checkpoint with ``os.replace``, which is atomic on POSIX filesystems.
    An interruption mid-write (kernel restart, Colab disconnect) then leaves
    the previous checkpoint intact, instead of a truncated parquet file that
    ``load_checkpoint`` can no longer read.
    """
    checkpoint_path = CHECKPOINT_DIR / 'download_progress.parquet'
    tmp_path = checkpoint_path.with_name(checkpoint_path.name + '.tmp')
    try:
        progress_df.to_parquet(tmp_path, index=False)
        os.replace(tmp_path, checkpoint_path)
    finally:
        # No-op after a successful rename; removes a partial file otherwise.
        tmp_path.unlink(missing_ok=True)
    print(f"💾 Checkpoint saved: {len(progress_df)} downloads")


print("✅ Download functions defined")

## Cell 6: 批量下载（主要执行）

In [None]:
# Load progress from any previous (possibly interrupted) run.
progress_df = load_checkpoint()

# A sample retried across runs would otherwise have one row per attempt. Keep
# only its latest row so resume logic and statistics count each sample once.
if len(progress_df) > 0:
    progress_df = progress_df.drop_duplicates(subset='sample_id', keep='last').reset_index(drop=True)

# Decide which samples still need downloading. Previously failed ones are retried.
if len(progress_df) > 0:
    completed_ids = set(progress_df[progress_df['status'].isin(['success', 'cached'])]['sample_id'])
    remaining_samples = samples_df[~samples_df['sample_id'].isin(completed_ids)]
else:
    remaining_samples = samples_df.copy()


def _merge_progress(previous: pd.DataFrame, new_results: list) -> pd.DataFrame:
    """Append new result dicts to previous progress, keeping the latest row per sample."""
    merged = pd.concat([previous, pd.DataFrame(new_results)], ignore_index=True)
    return merged.drop_duplicates(subset='sample_id', keep='last').reset_index(drop=True)


print(f"📊 Download Progress:")
print(f"   Total samples: {len(samples_df):,}")
print(f"   Already downloaded: {len(samples_df) - len(remaining_samples):,}")
print(f"   Remaining: {len(remaining_samples):,}")

if len(remaining_samples) == 0:
    print("\n✅ All lightcurves already downloaded!")
else:
    print(f"\n🚀 Starting download for {len(remaining_samples):,} samples")
    print(f"   Workers: {CONFIG['max_workers']}")
    print(f"   Estimated time: {len(remaining_samples) * 5 / 3600 / CONFIG['max_workers']:.1f} hours")
    print(f"   (assuming 5 sec/sample with {CONFIG['max_workers']} workers)")
    
    start_time = time.time()
    results = []
    status_tally = {'success': 0, 'cached': 0, 'failed': 0}
    
    # Manage the executor explicitly. On interrupt, `with` would wait for every
    # queued download to finish; here queued work is cancelled and the
    # finished results are still saved.
    executor = ThreadPoolExecutor(max_workers=CONFIG['max_workers'])
    try:
        future_to_row = {
            executor.submit(download_single_lightcurve, row, CONFIG['max_retries']): row
            for _, row in remaining_samples.iterrows()
        }
        
        with tqdm(total=len(remaining_samples), desc="Downloading") as pbar:
            for future in as_completed(future_to_row):
                try:
                    result = future.result()
                except Exception as e:
                    # Record an unexpected worker error as a failure rather than
                    # aborting the batch and losing unsaved progress.
                    failed_row = future_to_row[future]
                    result = {
                        'sample_id': failed_row['sample_id'],
                        'tic_id': None,
                        'status': 'failed',
                        'file_path': None,
                        'n_sectors': 0,
                        'error': f"{type(e).__name__}: {e}",
                    }
                results.append(result)
                status_tally[result['status']] = status_tally.get(result['status'], 0) + 1
                pbar.update(1)
                
                # Periodic checkpoint plus a statistics refresh.
                if len(results) % CONFIG['save_interval'] == 0:
                    save_checkpoint(_merge_progress(progress_df, results))
                    pbar.set_postfix(status_tally)
    finally:
        # No-op on normal completion. On interrupt, drop downloads that have
        # not started yet, wait only for those in flight, then persist.
        executor.shutdown(wait=True, cancel_futures=True)
        if len(results) > 0:
            progress_df = _merge_progress(progress_df, results)
            save_checkpoint(progress_df)
    
    elapsed = time.time() - start_time
    
    print(f"\n🎉 Download complete!")
    print(f"   Total time: {elapsed / 3600:.2f} hours")
    print(f"   Average: {elapsed / max(len(results), 1):.1f} sec/sample")

# Final statistics (one row per sample thanks to the de-duplication above).
print(f"\n📊 Final Statistics:")
if len(progress_df) > 0:
    status_counts = progress_df['status'].value_counts()
else:
    status_counts = pd.Series(dtype='int64')
for status, count in status_counts.items():
    print(f"   {status}: {count:,}")

n_completed = status_counts.get('success', 0) + status_counts.get('cached', 0)
success_rate = n_completed / len(progress_df) * 100 if len(progress_df) > 0 else 0.0
print(f"\n   Success rate: {success_rate:.1f}%")
print(f"   Total lightcurves: {len(list(LIGHTCURVE_DIR.glob('*.pkl'))):,}")

## Cell 7: 验证下载数据

In [None]:
# Spot-check a few random downloaded files to confirm they load and contain data.
print("🔍 Verifying downloaded data...\n")

pkl_files = list(LIGHTCURVE_DIR.glob('*.pkl'))
# Defined up front so later cells can use it even when no files exist.
total_size = 0.0
if len(pkl_files) == 0:
    print("❌ No lightcurve files found!")
else:
    # Pick up to 3 distinct files at random.
    rng = np.random.default_rng()
    picked = rng.choice(len(pkl_files), size=min(3, len(pkl_files)), replace=False)
    
    for idx in picked:
        pkl_file = pkl_files[idx]
        try:
            # NOTE: joblib.load unpickles arbitrary objects; only load files
            # this notebook wrote itself.
            data = joblib.load(pkl_file)
            print(f"✅ {pkl_file.name}")
            print(f"   TIC ID: {data['tic_id']}")
            print(f"   Sectors: {data['n_sectors']} ({data['sectors']})")
            print(f"   Downloaded: {data['download_time']}")
            
            # Inspect the first light curve in the collection.
            lc = data['lc_collection'][0]
            print(f"   Data points: {len(lc.time):,}")
            # Use the raw time values instead of float() on an astropy
            # TimeDelta. NOTE(review): assumes the time format is in days
            # (TESS BTJD) — confirm.
            t_vals = lc.time.value
            print(f"   Time span: {float(t_vals[-1] - t_vals[0]):.1f} days")
            print()
        except Exception as e:
            print(f"❌ {pkl_file.name}: {e}\n")
    
    print(f"\n📦 Storage:")
    total_size = sum(f.stat().st_size for f in pkl_files) / 1024 / 1024 / 1024
    print(f"   Total files: {len(pkl_files):,}")
    print(f"   Total size: {total_size:.2f} GB")
    print(f"   Average size: {total_size * 1024 / len(pkl_files):.1f} MB/file")

## Cell 8: 生成下载报告

In [None]:
# Build and save a JSON download report from the progress table and the files
# on disk. Storage stats are recomputed here so this cell does not depend on
# the verification cell having run or having found any files.
pkl_files = list(LIGHTCURVE_DIR.glob('*.pkl'))
total_size = sum(f.stat().st_size for f in pkl_files) / 1024 / 1024 / 1024

# Count each sample once, using its most recent attempt.
if len(progress_df) > 0 and 'status' in progress_df.columns:
    final_df = progress_df.drop_duplicates(subset='sample_id', keep='last')
else:
    final_df = pd.DataFrame(columns=['sample_id', 'status', 'error'])

report_counts = final_df['status'].value_counts()
n_downloaded = int(report_counts.get('success', 0) + report_counts.get('cached', 0))
report_success_rate = n_downloaded / len(final_df) * 100 if len(final_df) > 0 else 0.0
error_counts = final_df.loc[final_df['status'] == 'failed', 'error'].value_counts()

report = {
    'timestamp': datetime.now().isoformat(),
    'total_samples': len(samples_df),
    'downloaded': n_downloaded,
    'failed': int(report_counts.get('failed', 0)),
    'success_rate': float(report_success_rate),
    'config': CONFIG,
    'storage': {
        'directory': str(LIGHTCURVE_DIR),
        'total_files': len(pkl_files),
        'total_size_gb': float(total_size)
    },
    # Coerce to native str/int so json.dump never sees numpy scalar types.
    'errors': {str(err): int(cnt) for err, cnt in error_counts.items()},
}

report_path = CHECKPOINT_DIR / 'download_report.json'
with open(report_path, 'w', encoding='utf-8') as f:
    json.dump(report, f, indent=2)

print(f"✅ Download report saved: {report_path}")
print(f"\n📋 Summary:")
print(f"   Downloaded: {report['downloaded']:,} / {report['total_samples']:,}")
print(f"   Success rate: {report['success_rate']:.1f}%")
print(f"   Storage: {report['storage']['total_size_gb']:.2f} GB")

if report['errors']:
    print(f"\n⚠️ Error breakdown:")
    for error, count in report['errors'].items():
        print(f"   {error}: {count}")

## ✅ 下载完成！

### 下一步:
1. 运行 `02b_extract_features.ipynb` 进行特征提取
2. 特征提取可以多次运行，不需要重新下载
3. 可以尝试不同的特征提取策略

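### 文件位置:
- **光曲线数据**: 本地为 `data/lightcurves/`，Colab 为 `/content/drive/MyDrive/spaceapps-lightcurves/`（每个样本一个 `{sample_id}_TIC{tic_id}.pkl`）
- **下载进度**: `checkpoints/download_progress.parquet`（Colab 为 `/content/drive/MyDrive/spaceapps-checkpoints/` 下同名文件）
- **下载报告**: `checkpoints/download_report.json`（Colab 为 `/content/drive/MyDrive/spaceapps-checkpoints/` 下同名文件）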