In [7]:
%load_ext autoreload
%autoreload 2
# With autoreload in mode 2, saving a .py file under src/ is picked up by the notebook automatically — no kernel restart needed.
# NOTE(review): this cell was executed after the imports cell (In [7] vs In [6]); keep it as the very first cell so Restart & Run All enables autoreload before `src` is imported.

In [6]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch

# 引入你的 src 模組
from src.utils.data_loader import prepare_data
from src.engine.trainer import train_v11
from src.engine.evaluator import evaluate_model

# Plot styling: seaborn "whitegrid" theme plus CJK-capable fonts so Chinese labels render.
sns.set_style("whitegrid")
# Assign fonts AFTER sns.set_style, because seaborn's style also writes font.sans-serif.
# Matplotlib tries the families in order. The first two are the original choices;
# the rest cover machines where those are not installed (e.g. Linux). DejaVu Sans ships
# with matplotlib, so Latin text always resolves.
plt.rcParams['font.sans-serif'] = [
    'Arial Unicode MS',    # macOS / MS Office
    'Microsoft JhengHei',  # Windows, Traditional Chinese
    'PingFang TC',         # macOS, Traditional Chinese
    'Noto Sans CJK TC',    # common CJK font on Linux
    'DejaVu Sans',         # bundled with matplotlib
]
# Draw the minus sign as ASCII '-'; many CJK fonts lack the Unicode minus (U+2212).
plt.rcParams['axes.unicode_minus'] = False


In [None]:
# Hardware: use the GPU when one is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using Device: {DEVICE}")

# Data & model parameters
DATASET_PATH = Path("../dataset/USD_TWD.csv")  # relative to the notebook's working directory
HORIZON = 3  # forecast the next 3 days
LOOKBACK = 30  # look back over the past 30 days
NUM_EPOCHS = 100  # number of training epochs
LR = 0.001  # learning rate
SEED = 42  # fixed random seed

# Reproducibility: seed every RNG the pipeline may touch, not just torch/numpy.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Seeds every visible GPU (not only the current one); silently ignored without CUDA,
# so no availability guard is needed.
torch.cuda.manual_seed_all(SEED)
# cuDNN may pick non-deterministic / autotuned kernels by default; pin them so
# repeated GPU runs give the same results (at some speed cost).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Fail fast: a missing dataset must stop Run All here, instead of surfacing later
# as a confusing NameError on train_loader / test_loader in the training cell.
if not DATASET_PATH.exists():
    raise FileNotFoundError(f"[Error]: Dataset not found at {DATASET_PATH.resolve()}")

print("[Loading Data]")
df = pd.read_csv(DATASET_PATH)

# Build the loaders with the project helper from src.utils.data_loader.
# Only the two loaders and the raw scalers are used in this notebook; the other
# four return values are discarded.
train_loader, test_loader, scalers_raw, _, _, _, _ = prepare_data(
    df, lookback=LOOKBACK, horizon=HORIZON
)

# Sanity-check one batch. An empty loader would otherwise raise a bare StopIteration.
sample = next(iter(train_loader), None)
if sample is None:
    raise ValueError(
        f"train_loader yielded no batches (lookback={LOOKBACK}, horizon={HORIZON}); "
        "the dataset may be too short for these settings."
    )
print(f"[Loading Data] Data Loaded! Train Batches: {len(train_loader)}, Test Batches: {len(test_loader)}")
# NOTE(review): assumes each batch is a dict with 'raw_input' and 'target' tensors — confirm against prepare_data.
print(f"[Loading Data] Input Shape: {sample['raw_input'].shape}, Target Shape: {sample['target'].shape}")

In [None]:
# Train with `train_v11` from src.engine.trainer, using the settings from the config cell.
print(f"[Training] Starting Training for {NUM_EPOCHS} epochs...")

train_kwargs = dict(
    train_loader=train_loader,
    test_loader=test_loader,
    device=DEVICE,
    horizon=HORIZON,
    num_epochs=NUM_EPOCHS,
    lr=LR,
)
model = train_v11(**train_kwargs)

print("[Training] Training Completed.")

In [None]:
# Evaluate the trained model on test_loader with `evaluate_model` from src.engine.evaluator.
print("[Evaluation] Running Evaluation...")

eval_kwargs = dict(
    model=model,
    test_loader=test_loader,
    scaler=scalers_raw['target'],  # pass only the target scaler, not the whole dict
    device=DEVICE,
    horizon=HORIZON,
)
# Left as the cell's last expression so any return value is still displayed.
evaluate_model(**eval_kwargs)