<a href="https://colab.research.google.com/github/nanpolend/machine-learning/blob/master/kaggle%E6%95%B8%E6%93%9A%E9%9B%86.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
"""Jane Street市场预测完整流程（增强稳定版）"""
import os
import sys
import zipfile
import hashlib
import numpy as np
from urllib.request import urlretrieve
from tqdm import tqdm

# ========== 全局配置 ==========
CONFIG = {
    # Directory that holds downloaded archives, extracted files and outputs.
    "BASE_PATH": "./jane_street_data",
    # Number of download attempts made for the direct HTTP source.
    "MAX_RETRIES": 3,
    # Expected checksum per file name, written as "<algorithm>:<hex digest>".
    # NOTE: the values below are truncated example placeholders, not real
    # digests, so a comparison against them cannot succeed.
    "FILE_CHECKSUMS": {
        "train.csv": "sha256:8a4a1749b3c2...",
        "features.csv": "sha256:7b3e2d1a9f4...",
    },
    # Data sources in priority order. The downloader addresses them by
    # position: index 0 is the Kaggle API source, index 1 the plain URL.
    "DATA_SOURCES": [
        {
            "name": "Kaggle官方源",
            "type": "kaggle",
            "competition": "jane-street-market-prediction",
            "auth_required": True,
        },
        {
            "name": "Google云存储备份",
            "type": "url",
            "url": "https://storage.googleapis.com/kaggle-competitions-data/kaggle-v1/25960/compressed/jane-street-market-prediction.zip",
            "auth_required": False,
        },
    ],
}

# ========== 环境初始化 ==========
def setup_environment():
    """Prepare the data directory and choose the DataFrame engine.

    Creates ``CONFIG["BASE_PATH"]`` if it does not exist, makes sure pandas
    can be imported (attempting a pip install on demand) and then prefers
    cuDF whenever that module can be imported.

    Returns:
        The ``cudf`` module if it imports successfully, otherwise ``pandas``.

    Raises:
        ImportError: If pandas is still not importable after the install
            attempt.
    """
    import importlib
    import subprocess

    os.makedirs(CONFIG["BASE_PATH"], exist_ok=True)

    # Install core dependencies on demand.
    try:
        import pandas as pd
    except ImportError:
        print("正在安装Pandas...")
        # Run pip through the interpreter that executes this script so the
        # package lands in the environment we import from; a bare ``pip`` on
        # PATH may belong to a different Python installation.
        # ``check=False`` keeps the previous contract: a failed install
        # surfaces as ImportError from the import below.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-q", "pandas"],
            check=False,
        )
        # Let the import system notice the freshly installed package.
        importlib.invalidate_caches()
        import pandas as pd

    # GPU support is detected purely by whether cuDF can be imported.
    try:
        import cudf
    except ImportError:
        print("⚠️  CUDA不可用，使用Pandas进行CPU处理")
        return pd

    print("✅ 检测到CUDA环境，启用GPU加速")
    return cudf

# ========== 增强型下载器 ==========
class DataDownloader:
    """Fetch the competition archive from the configured sources and unpack it.

    Sources are tried in a fixed order: the Kaggle API first, then the direct
    HTTP URL. The private helpers report failure by returning ``False``; only
    :meth:`execute` raises.
    """

    def __init__(self):
        # setup_environment() returns the DataFrame module to use
        # (cudf when importable, otherwise pandas).
        self.df_engine = setup_environment()
        # Kept for compatibility; nothing in this class updates it.
        self.retry_count = 0

    class DownloadError(Exception):
        """Raised by execute() when no source yields an extractable archive."""

    def _validate_file(self, file_path):
        """Compare the SHA-256 of ``file_path`` with the configured checksum.

        The expected value is looked up in ``CONFIG["FILE_CHECKSUMS"]`` by the
        file's base name; everything after the last ``:`` is taken as the hex
        digest. A file without a configured checksum never validates.

        Args:
            file_path: Path of the file to check.

        Returns:
            ``True`` if the digests match, ``False`` on mismatch or on any
            error while reading the file.
        """
        try:
            digest = hashlib.sha256()
            with open(file_path, 'rb') as f:
                # Hash in 1 MiB chunks so large data files are never held in
                # memory in one piece.
                for chunk in iter(lambda: f.read(1 << 20), b""):
                    digest.update(chunk)
            expected_hash = CONFIG["FILE_CHECKSUMS"].get(
                os.path.basename(file_path), "").split(":")[-1]
            return digest.hexdigest() == expected_hash
        except Exception as e:
            # Best effort: validation problems are reported, not raised.
            print(f"校验失败: {str(e)}")
            return False

    def _download_with_progress(self, url, save_path):
        """Download ``url`` to ``save_path`` while printing a progress line.

        Args:
            url: Location to fetch.
            save_path: Destination file path.

        Returns:
            ``True`` when the transfer finished, ``False`` if it raised.
        """
        try:
            def _report(block_num, block_size, total_size):
                # Progress can only be shown when the total size is known.
                if total_size > 0:
                    # The last block may overshoot the total, so cap at 100%.
                    progress = min(1.0, block_num * block_size / total_size)
                    sys.stdout.write(f"\r下载进度: {progress:.1%}")
                    sys.stdout.flush()

            urlretrieve(url, save_path, _report)
            print("\n下载完成")
            return True
        except Exception as e:
            print(f"\n下载失败: {str(e)}")
            return False

    def _handle_kaggle(self):
        """Download the competition files through the Kaggle API.

        The ``kaggle`` package is imported lazily so that a missing package
        or missing credentials only disable this source.

        Returns:
            ``True`` if the API call completed, ``False`` on any error.
        """
        try:
            from kaggle.api.kaggle_api_extended import KaggleApi
            api = KaggleApi()
            api.authenticate()
            print("正在从Kaggle下载数据集...")
            api.competition_download_files(
                CONFIG["DATA_SOURCES"][0]["competition"],
                path=CONFIG["BASE_PATH"]
            )
            return True
        except Exception as e:
            print(f"Kaggle下载失败: {str(e)}")
            if "Could not find kaggle.json" in str(e):
                print("请配置Kaggle API密钥：https://github.com/Kaggle/kaggle-api#api-credentials")
            return False

    def _handle_http(self):
        """Download the archive from the direct URL source, with retries.

        Makes up to ``CONFIG["MAX_RETRIES"]`` attempts and stores the result
        as ``dataset.zip`` inside ``CONFIG["BASE_PATH"]``.

        Returns:
            ``True`` as soon as one attempt succeeds, otherwise ``False``.
        """
        target = CONFIG["DATA_SOURCES"][1]
        print(f"正在从{target['name']}下载...")
        save_path = os.path.join(CONFIG["BASE_PATH"], "dataset.zip")

        max_retries = CONFIG["MAX_RETRIES"]
        for attempt in range(1, max_retries + 1):
            if self._download_with_progress(target["url"], save_path):
                return True
            # Only announce a retry when another attempt actually follows.
            if attempt < max_retries:
                print(f"重试下载 ({attempt}/{max_retries})")
        return False

    def _unzip_files(self):
        """Extract the first valid ``.zip`` archive found in the base directory.

        Each archive is integrity-checked before anything is extracted. A
        corrupt archive is deleted and the next one is tried; a successfully
        extracted archive is deleted as well.

        Returns:
            ``True`` once one archive has been extracted, ``False`` if no
            usable archive was found.
        """
        base_path = CONFIG["BASE_PATH"]
        for f in os.listdir(base_path):
            if not f.endswith(".zip"):
                continue
            zip_path = os.path.join(base_path, f)
            try:
                print(f"正在解压 {f}...")
                with zipfile.ZipFile(zip_path) as zf:
                    # testzip() returns the name of the first corrupt member
                    # (or None) instead of raising, so its result must be
                    # checked to reject the archive before extraction starts.
                    bad_member = zf.testzip()
                    if bad_member is not None:
                        raise zipfile.BadZipFile(
                            f"corrupt archive member: {bad_member}")
                    for member in tqdm(zf.infolist(), desc="解压进度"):
                        zf.extract(member, base_path)
                os.remove(zip_path)
                return True
            except zipfile.BadZipFile:
                print("⚠️  ZIP文件损坏，已删除")
                os.remove(zip_path)
        return False

    def execute(self):
        """Run the full acquisition flow: Kaggle first, then the HTTP source.

        Returns:
            ``True`` when one source was downloaded and extracted.

        Raises:
            DataDownloader.DownloadError: If neither source produced an
                extractable archive.
        """
        print("\n" + "="*40)
        print("开始数据获取流程".center(40))
        print("="*40)

        # Preferred source: Kaggle API.
        if self._handle_kaggle() and self._unzip_files():
            print("✅ Kaggle数据下载成功")
            return True

        # Fallback source: direct HTTP download.
        if self._handle_http() and self._unzip_files():
            print("✅ 备用数据下载成功")
            return True

        raise self.DownloadError("❌ 所有数据源均不可用")

# ========== 数据加载器 ==========
class DataLoader:
    """Load the training data, falling back to generated mock data.

    The DataFrame engine (cudf or pandas module) is taken from the
    downloader, so real and mock data are built with the same library.
    """

    def __init__(self):
        self.downloader = DataDownloader()
        self.engine = self.downloader.df_engine

    @staticmethod
    def _normalize_columns(columns):
        """Return the column labels as strings with outer whitespace removed.

        Labels are otherwise left untouched so they stay unique. Truncating
        at the first underscore would turn every ``feature_<n>`` column into
        ``feature`` and ``ts_id`` into ``ts``, which yields duplicate labels
        and no longer matches the schema produced by the mock data.
        """
        return [str(col).strip() for col in columns]

    def _load_real_data(self):
        """Read ``train.csv`` from the base directory.

        A failed checksum validation only prints a warning; loading continues.

        Returns:
            The loaded DataFrame, or ``None`` if the file is missing or any
            error occurs while reading it.
        """
        try:
            train_path = f"{CONFIG['BASE_PATH']}/train.csv"
            if not os.path.exists(train_path):
                raise FileNotFoundError("未找到数据文件")

            if not self.downloader._validate_file(train_path):
                print("⚠️  文件校验未通过，数据可能不完整")

            print("正在加载数据文件...")
            df = self.engine.read_csv(train_path)

            # Normalise labels without collapsing distinct columns.
            df.columns = self._normalize_columns(df.columns)
            return df
        except Exception as e:
            print(f"数据加载失败: {str(e)}")
            return None

    def _generate_mock_data(self):
        """Build a reproducible synthetic dataset with the configured engine.

        The global NumPy random state is seeded with 42, so repeated calls
        produce the same values.

        Returns:
            A DataFrame with 10000 rows and the columns ``date``, ``weight``,
            ``resp``, ``feature_0`` .. ``feature_129`` and ``ts_id``.
        """
        print("正在生成模拟数据...")
        np.random.seed(42)
        n_rows = 10000
        # The draw order below determines the generated values; keep it.
        data = {
            'date': np.random.randint(0, 500, n_rows),
            'weight': np.random.uniform(0, 1, n_rows),
            'resp': np.random.normal(0, 0.01, n_rows),
            **{f'feature_{i}': np.random.normal(0, 1, n_rows) for i in range(130)},
            'ts_id': np.arange(n_rows)
        }
        return self.engine.DataFrame(data)

    def load(self):
        """Download and load the real data, or fall back to mock data.

        Returns:
            The real dataset when download and loading both succeed;
            otherwise the generated mock dataset. A ``DownloadError`` is
            printed and handled here; other exceptions raised by the
            downloader propagate.
        """
        try:
            self.downloader.execute()
            df = self._load_real_data()
            if df is not None:
                return df
        except DataDownloader.DownloadError as e:
            print(str(e))

        print("将使用模拟数据继续运行")
        return self._generate_mock_data()

# ========== 主执行流程 ==========
# Script entry point: load the data, print a short preview and save it as
# Parquet. Any failure is reported and turned into exit status 1.
if __name__ == "__main__":
    try:
        # Build the loader (this also prepares the environment).
        loader = DataLoader()

        # Real data when available, mock data otherwise.
        df = loader.load()

        # Preview
        print("\n" + "="*40)
        print("数据预览".center(40))
        print("="*40)
        print(f"数据集维度: {df.shape}")
        print("\n前3行数据:")
        # ``pd`` is never bound at module level (pandas is only imported
        # inside setup_environment), so an isinstance check against
        # ``pd.DataFrame`` raises NameError here. Detect cuDF frames by their
        # ``to_pandas`` method instead; pandas frames are printed directly.
        preview = df.head(3)
        if hasattr(preview, "to_pandas"):
            preview = preview.to_pandas()
        print(preview)

        # Persist the result
        output_path = f"{CONFIG['BASE_PATH']}/processed.parquet"
        df.to_parquet(output_path)
        print(f"\n处理后的数据已保存至: {output_path}")

    except Exception as e:
        print(f"\n❌ 程序执行异常: {str(e)}")
        sys.exit(1)
