In [98]:
import yfinance as yf
import mplfinance as mpf
import matplotlib.pyplot as plt
import os
import pandas as pd

In [99]:
# Ticker symbols (".TW" suffix) to generate training samples for.
taiwan_stocks = [
    "2330.TW",
    "2302.TW",
    "2049.TW",
    "2305.TW",
    "2454.TW",
]

# Output folders: a root directory plus one sub-folder for chart images and
# one for the matching label files.
output_dir = "./train_data/"
output_img = "./train_data/image/"
output_label = "./train_data/label/"
for _folder in (output_dir, output_img, output_label):
    # exist_ok=True makes this cell safe to re-run.
    os.makedirs(_folder, exist_ok=True)

# Download window passed to yfinance as start/end.
# NOTE(review): the original comment described this as "the past ten years
# plus four extra days for the 5-day moving average", but the dates below
# cover 2019-12-26 through 2024-01-01 -- confirm the intended range.
start_date = "2019-12-26"  # starts a few days early (moving-average warm-up)
end_date = "2024-01-01"

# Step, in rows, between consecutive snapshots (used as the range() step in
# the generation loop).
# NOTE(review): the original comment said "one snapshot every 5 days" while
# the value is 4 -- confirm which is intended.
interval_days = 4

In [100]:
# Number of trading rows shown in each chart. It is also the moving-average
# window, so every image contains exactly the closes that feed the SMA value
# used for its label. (The column keeps the name 'SMA_5'.)
window = 5

for stock_id in taiwan_stocks:
    # Download daily bars, coerce every column to float and drop incomplete rows.
    ticker = yf.Ticker(stock_id)
    stock_data = ticker.history(start=start_date, end=end_date, interval="1d")
    stock_data = stock_data.apply(pd.to_numeric, errors='coerce').dropna().astype(float)

    # An empty download would otherwise fail on the 'Close' lookup below and
    # abort the remaining tickers; skip this symbol instead.
    if stock_data.empty or 'Close' not in stock_data.columns:
        print(f"{stock_id}: no price data returned, skipping")
        continue

    # Simple moving average of the close over the last `window` trading rows.
    # The first `window - 1` rows have no SMA (NaN); they are skipped in the
    # loop below rather than trimmed by a hard-coded date, which previously
    # ("2014-01-01") lay before the download start and removed nothing.
    stock_data['SMA_5'] = stock_data['Close'].rolling(window=window).mean()

    # Take one sample every `interval_days` rows.
    for i in range(0, len(stock_data), interval_days):
        # Not enough history yet for a full window / a defined SMA.
        if i < window - 1:
            continue

        idx = stock_data.index[i]

        # Select the last `window` trading rows ending at row i by position.
        # Slicing by calendar days (idx - 4 days) returned fewer than five
        # rows whenever the span crossed a weekend or holiday, so most
        # samples were silently discarded.
        plot_data = stock_data.iloc[i - window + 1:i + 1]

        # Draw the candlestick chart (no volume panel, no SMA overlay).
        fig, ax = mpf.plot(
            plot_data,
            type='candle',
            style='charles',
            ylabel="",
            volume=False,
            xrotation=0,
            returnfig=True
        )
        # Strip ticks, axis labels and title so only the candles remain.
        ax[0].set_xticks([])
        ax[0].set_yticks([])
        ax[0].set_xlabel("")
        ax[0].set_ylabel("")
        ax[0].set_title("")

        # Save the image, then close the figure so figures do not accumulate
        # in memory across iterations.
        image_path = os.path.join(output_img, f"{stock_id}_{idx.date()}.png")
        fig.savefig(image_path, dpi=100, bbox_inches='tight', pad_inches=0)
        plt.close(fig)

        # Trend label: 1 if the window's last close is above its 5-row SMA,
        # otherwise 0.
        row = plot_data.iloc[-1]
        trend_label = 1 if row['Close'] > row['SMA_5'] else 0
        label_name = f"{stock_id}_{idx.date()}.txt"
        label_path = os.path.join(output_label, label_name)

        # Write the label next to the image, using the same base file name.
        with open(label_path, "w") as label_file:
            label_file.write(str(trend_label))