<a href="https://colab.research.google.com/github/nanpolend/machine-learning/blob/master/%E5%88%A9%E7%94%A8kaggle%E8%B3%87%E6%96%99ai%E9%A0%90%E6%B8%AC%E9%BB%83%E9%87%91%E8%B5%B0%E5%8B%A2test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
+---------------------+      +---------------------+      +---------------------+      +---------------------+
|  數據獲取模組        |---->|  特徵工程模組        |---->|  模型訓練模組        |---->|  價格預測模組        |
|  GoldDataFetcher  |      |  FeatureEngineer   |      |  GoldPricePredictor |      |  (predict方法)     |
+---------------------+      +---------------------+      +---------------------+      +---------------------+
                                 ^                                      ^
                                 |                                      |
                                 +--------------------------------------+
                                               |
                                               v
                                        +-------------+
                                        |  數據輸入   |
                                        | (歷史/實時)|
                                        +-------------+

In [None]:
!pip install --upgrade numpy pandas-ta



In [6]:
# -*- coding: utf-8 -*-
# 安裝必要套件：pip install requests pandas numpy pandas-ta scikit-learn xgboost shap tensorflow matplotlib

import requests
import pandas as pd
import numpy as np
import pandas_ta as ta
from sklearn.preprocessing import RobustScaler
from sklearn.model_selection import TimeSeriesSplit, GridSearchCV
from xgboost import XGBRegressor
import shap
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import LSTM, Dense, Dropout
import warnings
warnings.filterwarnings('ignore')

# ====================
# Configuration parameters
# ====================
API_KEY = "YOUR_API_KEY"  # Must be replaced with a valid API key
# Instrument identifier sent as the "symbol" query parameter of every request.
SYMBOL = "XAUUSD"
# Number of consecutive rows in one LSTM input window; also the minimum
# number of rows the predictor accepts.
LOOKBACK_DAYS = 30
# Default number of daily bars requested by GoldDataFetcher.fetch_historical.
TRAIN_DAYS = 1000

# ====================
# 數據獲取模組
# ====================
class GoldDataFetcher:
    """Download gold quotes over HTTP and shape them into pandas DataFrames.

    Both public fetchers are best-effort: every failure is reported with
    ``print`` and signalled to the caller by returning ``None`` rather than
    by raising. Both return frames indexed by timestamp so that their rows
    can be concatenated directly.
    """

    @staticmethod
    def _get_ok_response(url, params, timeout, label):
        """Send a GET request and return the response only when it is HTTP 200.

        Args:
            url: Endpoint to query.
            params: Query-string parameters.
            timeout: Request timeout in seconds.
            label: Prefix used in the printed failure message.

        Returns:
            The ``requests`` response, or ``None`` (after printing the status
            code) when the status is not 200. Network errors propagate.
        """
        response = requests.get(url, params=params, timeout=timeout)
        if response.status_code != 200:
            print(f"{label}API請求失敗，狀態碼: {response.status_code}")
            return None
        return response

    @staticmethod
    def fetch_realtime():
        """Fetch the latest quote as a one-row DataFrame.

        Returns:
            A frame with float columns ``open``, ``high``, ``low``, ``close``
            and ``volume``, indexed by the quote timestamp (parsed as epoch
            milliseconds), or ``None`` when the request or parsing fails.
        """
        url = "https://api.alltick.co/v1/quote"
        params = {"symbol": SYMBOL, "apikey": API_KEY}
        try:
            response = GoldDataFetcher._get_ok_response(url, params, 10, "實時")
            if response is None:
                return None

            data = response.json()
            quote = pd.DataFrame([{
                "timestamp": pd.to_datetime(data['Timestamp'], unit='ms'),
                "open": float(data['Open']),
                "high": float(data['High']),
                "low": float(data['Low']),
                "close": float(data['Close']),
                "volume": float(data['Volume'])
            }])
            # Index by timestamp like fetch_historical(): a leftover
            # 'timestamp' column would be NaT on every historical row after
            # pd.concat and make dropna() discard the whole history.
            return quote.set_index('timestamp')
        except Exception as e:
            # Deliberately broad: callers rely on None instead of exceptions.
            print(f"實時數據獲取異常: {str(e)}")
            return None

    @staticmethod
    def fetch_historical(days=TRAIN_DAYS):
        """Fetch daily bars and append the technical indicators.

        Args:
            days: Value sent as the ``limit`` query parameter.

        Returns:
            A float frame indexed by timestamp (parsed as epoch milliseconds)
            and sorted chronologically, with RSI(14), MACD(12, 26, 9),
            Bollinger Bands(20) and EMA(50) columns appended and every row
            containing a NaN removed; ``None`` when the request fails or the
            payload has no non-empty ``data`` entry.
        """
        url = "https://api.alltick.co/v1/history"
        params = {
            "symbol": SYMBOL,
            "interval": "1d",
            "apikey": API_KEY,
            "limit": days
        }
        try:
            response = GoldDataFetcher._get_ok_response(url, params, 15, "歷史")
            if response is None:
                return None

            raw_data = response.json()
            if 'data' not in raw_data or not raw_data['data']:
                print("歷史數據格式錯誤")
                return None

            df = pd.DataFrame(raw_data['data'])
            df = df.rename(columns={
                "Timestamp": "timestamp",
                "Open": "open",
                "High": "high",
                "Low": "low",
                "Close": "close",
                "Volume": "volume"
            })
            df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
            df.set_index('timestamp', inplace=True)
            df = df.astype(float)
            # The indicators below are order dependent, so make sure the bars
            # are chronological regardless of the order the API returned them.
            df = df.sort_index()

            # Technical indicators
            df.ta.rsi(length=14, append=True)
            df.ta.macd(fast=12, slow=26, signal=9, append=True)
            df.ta.bbands(length=20, append=True)
            df.ta.ema(length=50, append=True)
            # Drop the indicator warm-up rows, which still contain NaN.
            return df.dropna()
        except Exception as e:
            # Deliberately broad: callers rely on None instead of exceptions.
            print(f"歷史數據獲取異常: {str(e)}")
            return None

# ====================
# 特徵工程模組
# ====================
class FeatureEngineer:
    """Derive the custom model factors from a price frame with indicators."""

    @staticmethod
    def add_features(df):
        """Append the custom factors and return only fully populated rows.

        The new columns are assigned on ``df`` itself, so the caller's frame
        gains them too; the returned frame is a separate object from which
        every row containing a NaN has been removed.

        Args:
            df: Frame with ``close``, ``high``, ``low`` and the
                ``BBU_20_2.0`` / ``BBM_20_2.0`` / ``BBL_20_2.0`` columns.

        Returns:
            The frame with ``price_change``, ``volatility``, ``bb_width`` and
            ``momentum_5`` added, without NaN rows.
        """
        # Any MACD column that is absent is created and filled with zero.
        for macd_name in ('MACD_12_26_9', 'MACDh_12_26_9', 'MACDs_12_26_9'):
            if macd_name in df.columns:
                continue
            df[macd_name] = 0

        closing = df['close']
        df['price_change'] = closing.pct_change()   # one-row relative change

        daily_range = df['high'] - df['low']
        df['volatility'] = daily_range

        band_span = df['BBU_20_2.0'] - df['BBL_20_2.0']
        df['bb_width'] = band_span / df['BBM_20_2.0']

        df['momentum_5'] = closing.pct_change(5)    # five-row relative change

        return df.dropna()

# ====================
# 機器學習模組
# ====================
class GoldPricePredictor:
    """Forecast the next gold close with an XGBoost regressor and an LSTM.

    Features and target are scaled together by a single ``RobustScaler``.
    The target is always the last scaled column, which the methods rely on
    when slicing ``[:, :-1]`` (features) and ``[:, -1]`` (target).
    """

    def __init__(self):
        """Create the scaler, the feature list and both untrained models."""
        self.scaler = RobustScaler()
        self.features = ['RSI_14', 'MACD_12_26_9', 'MACDh_12_26_9',
                         'BBU_20_2.0', 'BBL_20_2.0', 'EMA_50',
                         'price_change', 'volatility', 'bb_width', 'momentum_5']
        self.target = 'close'
        self.xgb_model = XGBRegressor(objective='reg:squarederror')
        self.lstm_model = self.build_lstm()

    def build_lstm(self):
        """Build and compile the two-layer LSTM regressor.

        Returns:
            A compiled ``Sequential`` model taking windows of shape
            ``(LOOKBACK_DAYS, len(self.features))`` and emitting one value.
        """
        model = Sequential()
        model.add(LSTM(64,
                       return_sequences=True,
                       input_shape=(LOOKBACK_DAYS, len(self.features))))
        model.add(Dropout(0.3))
        model.add(LSTM(32))
        model.add(Dropout(0.3))
        model.add(Dense(1))
        model.compile(optimizer='adam', loss='mse')
        return model

    def _unscale_target(self, scaled_value):
        """Convert one scaled target value back to a price.

        The scaler was fitted on features plus target, so it needs a full
        width row. RobustScaler treats every column independently, which
        makes the placeholder zeros in the feature slots irrelevant.
        """
        row = np.zeros((1, len(self.features) + 1))
        row[0, -1] = scaled_value
        return self.scaler.inverse_transform(row)[0, -1]

    def train_models(self, df):
        """Fit the scaler, the grid-searched XGBoost model and the LSTM.

        Both models learn to forecast the close of the row that follows
        their input: XGBoost from the previous row's features, the LSTM from
        the previous ``LOOKBACK_DAYS`` rows of features.

        Note that ``FeatureEngineer.add_features`` assigns its columns on the
        frame it receives, so ``df`` gains the factor columns as a side effect.

        Args:
            df: Historical frame with the indicator columns and ``close``.

        Raises:
            ValueError: If there are too few rows before or after feature
                engineering.
        """
        if len(df) < LOOKBACK_DAYS * 2:
            raise ValueError(f"需要至少 {LOOKBACK_DAYS*2} 天歷史數據")

        df = FeatureEngineer.add_features(df)
        # Feature engineering drops rows; at least two LSTM windows are
        # needed so the 80/20 split below leaves a non-empty training part.
        if len(df) < LOOKBACK_DAYS + 2:
            raise ValueError(
                f"特徵工程後數據不足，需要至少 {LOOKBACK_DAYS + 2} 筆有效數據"
            )

        scaled_data = self.scaler.fit_transform(df[self.features + [self.target]])
        feature_matrix = scaled_data[:, :-1]
        close_series = scaled_data[:, -1]

        # XGBoost training: pair the features of row t with the close of row
        # t+1. Pairing them with the same row's close would only teach the
        # model to reproduce the current price from indicators that are
        # themselves computed from it, not to forecast the next one.
        X, y = feature_matrix[:-1], close_series[1:]
        tscv = TimeSeriesSplit(n_splits=3)
        param_grid = {
            'n_estimators': [100, 200],
            'learning_rate': [0.03, 0.05],
            'max_depth': [3, 5]
        }
        grid_search = GridSearchCV(
            self.xgb_model, param_grid,
            cv=tscv, scoring='neg_mean_squared_error', n_jobs=-1
        )
        grid_search.fit(X, y)
        self.xgb_model = grid_search.best_estimator_

        # LSTM training: window [s, s + LOOKBACK_DAYS) predicts the close of
        # row s + LOOKBACK_DAYS.
        window_count = len(scaled_data) - LOOKBACK_DAYS
        X_lstm = np.array([
            feature_matrix[start:start + LOOKBACK_DAYS]
            for start in range(window_count)
        ])
        y_lstm = close_series[LOOKBACK_DAYS:]

        # Chronological split: the most recent 20% of windows validate.
        split_idx = int(len(X_lstm) * 0.8)
        self.lstm_model.fit(
            X_lstm[:split_idx], y_lstm[:split_idx],
            validation_data=(X_lstm[split_idx:], y_lstm[split_idx:]),
            epochs=30,
            batch_size=16,
            verbose=0
        )

    def predict(self, latest_data):
        """Forecast the next close from the most recent rows.

        Args:
            latest_data: Frame with the indicator columns and ``close``. It
                must still hold at least ``LOOKBACK_DAYS`` rows after feature
                engineering, which drops the first rows of its input.

        Returns:
            ``(xgb_price, lstm_price)`` in original price units.

        Raises:
            ValueError: If too few rows are supplied or survive feature
                engineering.
        """
        if len(latest_data) < LOOKBACK_DAYS:
            raise ValueError(f"需要至少 {LOOKBACK_DAYS} 天數據")

        # Work on a copy so the caller's frame (possibly a slice) is not
        # modified by the in-place column assignments of add_features.
        processed_data = FeatureEngineer.add_features(latest_data.copy())
        if len(processed_data) < LOOKBACK_DAYS:
            # Without this check the reshape below fails with an opaque error.
            raise ValueError(
                f"特徵工程後僅剩 {len(processed_data)} 筆數據，需要至少 {LOOKBACK_DAYS} 筆"
            )

        scaled_input = self.scaler.transform(
            processed_data[self.features + [self.target]]
        )
        feature_rows = scaled_input[:, :-1]

        # Only the newest row matters for the tree model's forecast.
        xgb_scaled = self.xgb_model.predict(feature_rows[-1:])[0]
        lstm_input = feature_rows[-LOOKBACK_DAYS:].reshape(1, LOOKBACK_DAYS, -1)
        lstm_scaled = self.lstm_model.predict(lstm_input)[0, 0]

        # Undo the scaling
        xgb_price = self._unscale_target(xgb_scaled)
        lstm_price = self._unscale_target(lstm_scaled)
        return xgb_price, lstm_price

    def explain_features(self, data_sample):
        """Plot a SHAP summary of the XGBoost model for the given rows.

        Args:
            data_sample: Frame containing at least ``self.features``. The
                target column is optional; it is only needed to satisfy the
                scaler's column layout and is discarded afterwards.
        """
        sample = data_sample[self.features].copy()
        if self.target in data_sample.columns:
            sample[self.target] = data_sample[self.target]
        else:
            # Placeholder: the scaled target column is dropped right below.
            sample[self.target] = 0.0
        scaled_sample = self.scaler.transform(sample)[:, :-1]

        explainer = shap.TreeExplainer(self.xgb_model)
        shap_values = explainer.shap_values(scaled_sample)
        # Pass the feature matrix as well as the names so the plot can colour
        # each point by its feature value.
        shap.summary_plot(shap_values, scaled_sample,
                          feature_names=self.features, show=False)
        plt.title('黃金價格影響因子權重分析')
        plt.tight_layout()
        plt.show()

# ====================
# 主程序
# ====================
# Script entry point: fetch history, train both models, plot the SHAP
# summary, then forecast the next close from history plus the latest quote.
if __name__ == "__main__":
    fetcher = GoldDataFetcher()
    predictor = GoldPricePredictor()

    print("正在獲取歷史數據...")
    historical_data = fetcher.fetch_historical()
    if historical_data is not None:
        try:
            print("開始訓練模型...")
            predictor.train_models(historical_data)
            print("模型訓練完成！\n")

            # Build the SHAP sample from an explicitly engineered copy rather
            # than from columns train_models() happens to leave behind, and
            # include the target column because explain_features() feeds the
            # sample through the scaler fitted on features + target.
            engineered = FeatureEngineer.add_features(historical_data.copy())
            explain_columns = predictor.features + [predictor.target]
            sample_data = engineered[explain_columns].sample(
                min(50, len(engineered)), random_state=42
            )
            predictor.explain_features(sample_data)

            print("獲取實時數據並預測...")
            realtime_data = fetcher.fetch_realtime()
            if realtime_data is not None:
                # Align the quote with the timestamp-indexed history; a stray
                # 'timestamp' column would be NaT on every historical row and
                # make dropna() discard all of them.
                if 'timestamp' in realtime_data.columns:
                    realtime_row = realtime_data.set_index('timestamp')
                else:
                    realtime_row = realtime_data

                price_columns = ['open', 'high', 'low', 'close', 'volume']
                combined = pd.concat(
                    [historical_data[price_columns], realtime_row[price_columns]],
                    axis=0
                )
                combined = combined[~combined.index.duplicated(keep='last')]
                combined = combined.sort_index()

                # The realtime row has no indicator values of its own, so it
                # would be dropped as NaN. Recompute the indicators on the
                # combined prices (same settings as fetch_historical).
                combined.ta.rsi(length=14, append=True)
                combined.ta.macd(fast=12, slow=26, signal=9, append=True)
                combined.ta.bbands(length=20, append=True)
                combined.ta.ema(length=50, append=True)

                full_data = FeatureEngineer.add_features(combined.copy())

                if len(full_data) >= LOOKBACK_DAYS:
                    # predict() repeats feature engineering, which drops the
                    # first rows of whatever it receives; hand it the whole
                    # frame instead of a LOOKBACK_DAYS tail that would end up
                    # too short.
                    xgb_price, lstm_price = predictor.predict(combined)
                    print("\n=== 預測結果 ===")
                    print(f"XGBoost預測下個收盤價: ${xgb_price:.2f}")
                    print(f"LSTM預測下個收盤價:   ${lstm_price:.2f}")
                    print(f"最新實時價格:        ${realtime_row['close'].values[0]:.2f}")
                else:
                    print("數據不足無法預測")
            else:
                print("實時數據獲取失敗")
        except Exception as e:
            # Top-level boundary of the script: report and stop.
            print(f"運行錯誤: {str(e)}")
    else:
        print("歷史數據獲取失敗")

正在獲取歷史數據...
歷史API請求失敗，狀態碼: 404
歷史數據獲取失敗
