# In [None]:  (notebook cell marker left over from the .ipynb export)
# Step 1: Environment setup (mount Google Drive when available, then import libraries).
try:
    from google.colab import drive
    drive.mount('/content/drive')
except ImportError:
    # google.colab could not be imported, so there is no Drive to mount.
    print("Drive mount skipped")
import math, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, roc_auc_score, precision_recall_curve, auc, f1_score

# Step 2: Input file paths.
# Google Drive locations, currently disabled in favour of the local paths below:
# TRAIN_PATH = "/content/drive/MyDrive/TSAT/BTC_5min/BTC_full_5min_Train.csv"
# VALID_PATH = "/content/drive/MyDrive/TSAT/BTC_5min/BTC_full_5min_Valid.csv"

TRAIN_PATH = "/home/nagumo/TSAT/BTC_5min/BTC_full_5min_Train.csv"
VALID_PATH = "/home/nagumo/TSAT/BTC_5min/BTC_full_5min_Valid.csv"

# Step 3: Fix the random seeds of Python, NumPy and PyTorch for reproducibility.
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Request deterministic cuDNN algorithms and switch off the cuDNN benchmark mode.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Step 4: Data loading and gap interpolation.
def load_and_interpolate(path):
    """Load a CSV with a ``date`` column and return it on a gap-free 5-minute grid.

    The file is indexed by ``date`` and sorted. Rows whose timestamp occurs
    more than once are collapsed to the last occurrence in the file, because
    ``reindex`` refuses an index with duplicate labels. The frame is then
    reindexed onto a regular 5-minute range spanning the first to the last
    timestamp, and the rows this introduces are filled by time-weighted
    interpolation.

    Args:
        path: Path of the CSV file to read.

    Returns:
        pd.DataFrame indexed by a regular 5-minute ``DatetimeIndex``.
    """
    df = pd.read_csv(path, parse_dates=['date']).set_index('date')
    # De-duplicate before sorting so that "last" refers to file order.
    df = df[~df.index.duplicated(keep='last')].sort_index()
    # '5min' is the supported spelling; the older 'T' alias is deprecated in pandas.
    idx = pd.date_range(df.index.min(), df.index.max(), freq='5min')
    return df.reindex(idx).interpolate(method='time')

# Step 5: Load both files, then keep only training rows dated 2020-01-01 or later.
df_train = load_and_interpolate(TRAIN_PATH)
df_valid = load_and_interpolate(VALID_PATH)
df_train = df_train[df_train.index >= '2020-01-01']
# Report the date span covered by each frame.
print("Train:", df_train.index.min(), "to", df_train.index.max())
print("Valid:", df_valid.index.min(), "to", df_valid.index.max())

# Hard label generation (k bars on each side, k=6 by default)
def make_hard_labels(lows, highs, k=6):
    """Mark strict local extrema with integer class labels.

    A position is labelled 1 when its low is strictly below every low in the
    k bars before it and the k bars after it. Otherwise it is labelled 2 when
    its high is strictly above every high in those same bars. Every other
    position, including the first and last k, keeps label 0.

    Args:
        lows: NumPy array of low prices.
        highs: NumPy array of high prices, same length as ``lows``.
        k: Number of bars inspected on each side.

    Returns:
        Integer NumPy array with one label per input position.
    """
    total = len(lows)
    labels = np.zeros(total, dtype=int)
    for center in range(k, total - k):
        before = slice(center - k, center)
        after = slice(center + 1, center + k + 1)

        low_now = lows[center]
        if low_now < lows[before].min() and low_now < lows[after].min():
            labels[center] = 1
            continue

        high_now = highs[center]
        if high_now > highs[before].max() and high_now > highs[after].max():
            labels[center] = 2
    return labels

# Extract the low/high columns as arrays and compute hard labels for both frames.
lows_train, highs_train = df_train['low'].values, df_train['high'].values
lows_val, highs_val     = df_valid['low'].values, df_valid['high'].values
hard_train = make_hard_labels(lows_train, highs_train, k=6)
hard_val   = make_hard_labels(lows_val, highs_val, k=6)
# Store the labels next to the price data.
df_train['hard_label'] = hard_train
df_valid['hard_label'] = hard_val

# Step 6: Score-based soft labelling helpers
def compute_scores(lows, highs, k=6):
    """Compute continuous "local minimum" and "local maximum" scores.

    For each position with k bars on both sides, the minimum score is 100
    times the mean, over the k offsets, of the summed relative gaps of the
    preceding and following lows above the current low. The maximum score is
    the mirrored quantity for the highs. Positions within k bars of either
    end keep a score of 0.

    Args:
        lows: NumPy array of low prices.
        highs: NumPy array of high prices, same length as ``lows``.
        k: Number of bars inspected on each side.

    Returns:
        Tuple ``(s_min, s_max)`` of float arrays with one score per position.
    """
    size = len(lows)
    trough_score = np.zeros(size)
    peak_score = np.zeros(size)
    for pos in range(k, size - k):
        left = slice(pos - k, pos)
        right = slice(pos + 1, pos + k + 1)

        # How far the neighbouring lows sit above the current low.
        low_here = lows[pos]
        low_left, low_right = lows[left], lows[right]
        trough_score[pos] = 100 * np.mean(
            (low_left - low_here) / low_left + (low_right - low_here) / low_right
        )

        # How far the current high sits above the neighbouring highs.
        high_here = highs[pos]
        high_left, high_right = highs[left], highs[right]
        peak_score[pos] = 100 * np.mean(
            (high_here - high_left) / high_left + (high_here - high_right) / high_right
        )
    return trough_score, peak_score

# Compute the min/max scores for both frames with the same k as the hard labels.
smin_train, smax_train = compute_scores(lows_train, highs_train, k=6)
smin_val, smax_val     = compute_scores(lows_val, highs_val, k=6)

def soft_label_from_scores(smin, smax):
    """Turn min/max scores into a 3-class soft label via a softmax.

    The logits are ``[1.5 - max(smin, smax), smin, smax]`` for the classes
    (Other, Min, Max), so the "Other" logit falls as either extremum score
    rises. Each row's maximum is subtracted before exponentiating. That leaves
    the softmax value unchanged but keeps ``np.exp`` from overflowing to
    ``inf`` (and the result from becoming NaN) for very large scores.

    Args:
        smin: 1-D array of local-minimum scores.
        smax: 1-D array of local-maximum scores, same length as ``smin``.

    Returns:
        Array of shape ``(len(smin), 3)`` whose rows sum to 1.
    """
    logits = np.stack([1.5 - np.maximum(smin, smax), smin, smax], axis=1)
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Build the soft labels for both frames.
soft_train = soft_label_from_scores(smin_train, smax_train)
soft_val   = soft_label_from_scores(smin_val, smax_val)
# list(...) splits the 2-D array into one row array per DataFrame row.
df_train['soft_label'] = list(soft_train)
df_valid['soft_label'] = list(soft_val)

# Visual check of the soft-labelling result

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# --- Expand the per-row soft labels into one 2-D numpy array ---
soft_array = np.array(df_train['soft_label'].tolist())
soft_other = soft_array[:, 0]
soft_min   = soft_array[:, 1]
soft_max   = soft_array[:, 2]

# --- Time index and price series ---
times       = df_train.index
low_series  = df_train['low']
high_series = df_train['high']
close_series = df_train['close']

# --- Restrict the plots to a single window ---
start, end = pd.to_datetime("2021-11-01 12:00:00"), pd.to_datetime("2021-11-01 23:00:00")
mask = (times >= start) & (times <= end)
# The titles are derived from the window itself so they cannot drift away
# from the data that is actually plotted.
period_label = f"{start:%Y-%m-%d %H:%M} to {end:%Y-%m-%d %H:%M}"

# --- Extract the window ---
times_f        = times[mask]
low_f          = low_series[mask]
high_f         = high_series[mask]
close_f       = close_series[mask]
soft_min_f     = soft_min[mask]
soft_max_f     = soft_max[mask]
soft_other_f   = soft_other[mask]

# --- Plot 1: Low Price vs Soft Min Score ---
fig, ax1 = plt.subplots(figsize=(12, 4))
ax1.plot(times_f, low_f, label='Low Price', color='tab:blue')
ax2 = ax1.twinx()
ax2.plot(times_f, soft_min_f, label='Soft Min Score', color='tab:orange', alpha=0.7)
ax2.set_ylim(0,1)
ax1.set_title(f'{period_label} Low Price vs Soft Min Score')
ax1.set_ylabel('Low Price')
ax2.set_ylabel('Soft Min Score')
ax1.legend(loc='upper left')
ax2.legend(loc='upper right')
plt.tight_layout()
plt.show()

# --- Plot 2: High Price vs Soft Max Score ---
fig, ax1 = plt.subplots(figsize=(12, 4))
ax1.plot(times_f, high_f, label='High Price', color='tab:blue')
ax2 = ax1.twinx()
ax2.plot(times_f, soft_max_f, label='Soft Max Score', color='tab:green', alpha=0.7)
ax2.set_ylim(0,1)
ax1.set_title(f'{period_label} High Price vs Soft Max Score')
ax1.set_ylabel('High Price')
ax2.set_ylabel('Soft Max Score')
ax1.legend(loc='upper left')
ax2.legend(loc='upper right')
plt.tight_layout()
plt.show()

# --- Plot 3: Close Price vs all three soft-label components ---
fig, ax1 = plt.subplots(figsize=(12, 4))
ax1.plot(times_f, close_f, label='Close Price', color='tab:blue')
ax2 = ax1.twinx()
ax2.plot(times_f, soft_min_f, label='Soft Min Score', color='tab:orange', alpha=0.7)
ax2.plot(times_f, soft_max_f,label='Soft Max Score', color='tab:green', alpha=0.7)
ax2.plot(times_f, soft_other_f, label='Soft Other Score', color='tab:red', alpha=0.7)
ax2.set_ylim(0,1)
ax1.set_title(f'{period_label} Close Price vs Soft Label Scores')
ax1.set_ylabel('Close Price')
ax2.set_ylabel('Soft Label Score')
ax1.legend(loc='upper left')
ax2.legend(loc='upper right')
plt.tight_layout()
plt.show()


# Step 7: Sequence feature and label generation (SEQ_LEN=288)
# 288 bars on the 5-minute grid cover 24 hours.
SEQ_LEN = 288
from sklearn.model_selection import train_test_split
import os
import numpy as np
from tqdm import tqdm # tqdm is normally imported once at the top of the script (e.g. in Step 1)

# Directory that receives the evaluation reports and model checkpoint; created if missing.
output_dir = "evaluation_results"
os.makedirs(output_dir, exist_ok=True)

def make_dataset(df, hard_lbl, soft_lbl, description="Creating dataset"):
    """Build sliding-window features and per-window targets from a price frame.

    For every end position ``t`` from ``SEQ_LEN - 1`` up to (but excluding)
    ``len(df) - 6``, the window of the last ``SEQ_LEN`` rows ending at ``t``
    becomes one sample. Close, high and low are divided by the window's final
    close, and volume by the window's final volume (all zeros when that
    volume is 0). Progress is shown with a tqdm bar.

    The price columns are converted to NumPy arrays once, before the loop,
    so each iteration only slices arrays instead of going through pandas
    indexing.

    Args:
        df (pd.DataFrame): Frame with 'close', 'high', 'low' and 'volume' columns.
        hard_lbl (np.array): Hard label per row of ``df``.
        soft_lbl (np.array): Soft label per row of ``df``.
        description (str, optional): Text shown on the tqdm progress bar.
            Defaults to "Creating dataset".

    Returns:
        tuple: ``(X, y_hard, y_soft, y_reg, t_indices)`` as NumPy arrays.
        ``X`` is float32 with one ``(SEQ_LEN, 4)`` window per sample,
        ``y_reg`` is the next close divided by the window's final close
        (1.0 when that close is 0), and ``t_indices`` holds each sample's
        row position ``t`` in ``df``.
    """
    closes = df['close'].to_numpy(dtype=float)
    highs = df['high'].to_numpy(dtype=float)
    lows = df['low'].to_numpy(dtype=float)
    volumes = df['volume'].to_numpy(dtype=float)
    n_rows = len(df)

    X, y_hard, y_soft, y_reg, t_indices = [], [], [], [], []

    # The last 6 rows are never used as window ends; presumably this matches
    # the k=6 look-ahead of the label functions, which leave those rows unlabelled.
    for t in tqdm(range(SEQ_LEN - 1, n_rows - 6), desc=description, unit="sequences", leave=True):
        lo, hi = t - SEQ_LEN + 1, t + 1

        # --- Features: normalise by the window's final close / volume ---
        c0 = closes[t]
        v0 = volumes[t]
        if v0 != 0:
            normalized_volume = volumes[lo:hi] / v0
        else:
            normalized_volume = np.zeros(hi - lo, dtype=float)

        feats = np.stack([closes[lo:hi] / c0,
                          highs[lo:hi] / c0,
                          lows[lo:hi] / c0,
                          normalized_volume], axis=1)
        X.append(feats)

        # --- Labels ---
        y_hard.append(hard_lbl[t])
        y_soft.append(soft_lbl[t])

        # t + 1 is always a valid row because the loop stops 6 rows before the end.
        reg_target = closes[t + 1] / c0 if c0 != 0 else 1.0
        y_reg.append(reg_target)

        t_indices.append(t)

    return np.array(X, dtype=np.float32), \
           np.array(y_hard), \
           np.array(y_soft, dtype=np.float32), \
           np.array(y_reg, dtype=np.float32), \
           np.array(t_indices)

# --- Build the datasets ---
# Each make_dataset call passes the text shown on its tqdm progress bar.
print("--- Training data creation ---")
X_tr, y_tr_hard, y_tr_soft, y_tr_reg, t_tr_indices = make_dataset(
    df_train, hard_train, soft_train, description="Processing df_train"
)
print("\n--- Validation/Test data creation ---")
X_val_all, y_val_hard_all, y_val_soft_all, y_val_reg_all, t_val_all_indices = make_dataset(
    df_valid, hard_val, soft_val, description="Processing df_valid"
)
print("") # Leave a blank line between the tqdm output and the following prints

# --- Split the held-out windows into validation and test sets ---
# Windows ending at consecutive positions share SEQ_LEN-1 of their SEQ_LEN
# rows, so a shuffled split would put near-identical samples on both sides
# and make the test score an optimistic copy of the validation score. The
# split is therefore chronological: the first half becomes validation and
# the second half becomes test. Only windows next to the cut still overlap.
if len(X_val_all) > 1:
    X_val, X_test, \
    y_val_hard, y_test_hard, \
    y_val_soft, y_test_soft, \
    y_val_reg, y_test_reg, \
    t_val_indices, t_test_indices = train_test_split(
        X_val_all, y_val_hard_all, y_val_soft_all, y_val_reg_all, t_val_all_indices,
        test_size=0.5, shuffle=False
    )
else:
    print("Warning: Not enough data in df_valid to split into validation and test sets after make_dataset. Using all for validation and test if applicable.")
    X_val, y_val_hard, y_val_soft, y_val_reg, t_val_indices = X_val_all, y_val_hard_all, y_val_soft_all, y_val_reg_all, t_val_all_indices
    X_test, y_test_hard, y_test_soft, y_test_reg, t_test_indices = X_val_all, y_val_hard_all, y_val_soft_all, y_val_reg_all, t_val_all_indices

# --- Print array shapes as a sanity check ('N/A' when a split is empty) ---
print(f"X_tr shape: {X_tr.shape}, y_tr_reg shape: {y_tr_reg.shape}")
print(f"X_val shape: {X_val.shape if len(X_val)>0 else 'N/A'}, y_val_reg shape: {y_val_reg.shape if len(X_val)>0 else 'N/A'}")
print(f"X_test shape: {X_test.shape if len(X_test)>0 else 'N/A'}, y_test_reg shape: {y_test_reg.shape if len(X_test)>0 else 'N/A'}")

# Step 8: Dataset & DataLoader
class BTCSeqDataset(Dataset):
    """Dataset over pre-built windows and their aligned targets.

    Each NumPy array passed to the constructor is wrapped as a tensor with
    ``torch.from_numpy``. Indexing returns the 5-tuple
    ``(X, y_hard, y_soft, y_reg, t_index)`` for the requested sample.
    """

    def __init__(self, X, y_hard, y_soft, y_reg, t_indices):
        wrap = torch.from_numpy
        self.X = wrap(X)
        self.y_hard = wrap(y_hard)
        self.y_soft = wrap(y_soft)
        self.y_reg = wrap(y_reg)
        self.t_indices = wrap(t_indices)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        fields = (self.X, self.y_hard, self.y_soft, self.y_reg, self.t_indices)
        return tuple(field[i] for field in fields)

# Batch size shared by all three loaders; only the training loader shuffles.
BATCH = 64
dl_tr = DataLoader(BTCSeqDataset(X_tr, y_tr_hard, y_tr_soft, y_tr_reg, t_tr_indices), batch_size=BATCH, shuffle=True)

# The validation and test loaders are set to None when their split is empty.
if len(X_val) > 0 :
    dl_val = DataLoader(BTCSeqDataset(X_val, y_val_hard, y_val_soft, y_val_reg, t_val_indices), batch_size=BATCH)
else:
    dl_val = None
    print("Validation DataLoader (dl_val) is not created due to empty X_val.")

if len(X_test) > 0:
    dl_test= DataLoader(BTCSeqDataset(X_test, y_test_hard, y_test_soft, y_test_reg, t_test_indices), batch_size=BATCH)
else:
    dl_test = None
    print("Test DataLoader (dl_test) is not created due to empty X_test.")
    
# Step 9: PositionalEncoding + TransformerClassifier with SoftLabel Loss
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    Even feature columns hold ``sin(pos * div)`` and odd columns hold
    ``cos(pos * div)``. An odd ``d_model`` is supported: it has one more sine
    column than cosine columns, so the cosine block is truncated to fit.

    Args:
        d_model: Size of the feature dimension.
        max_len: Longest sequence the pre-computed table can serve.
    """

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000)/d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there are only d_model // 2 odd columns to fill.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over the batch axis.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Raises:
            ValueError: If the sequence is longer than the pre-computed table.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len]

class TransformerClassifierSoftLabel(nn.Module):
    """Transformer encoder with a classification head and a regression head.

    The 4 input features are projected to ``d_model``, position-encoded and
    passed through the encoder. The encoder output at the last sequence
    position feeds both heads.
    """

    def __init__(self, d_model=120, nhead=3, num_layers=2, num_classes=3, lambda_reg=0.1):
        super().__init__()
        # Weight of the regression term in the combined loss.
        self.lambda_reg = lambda_reg
        self.proj = nn.Linear(4, d_model)
        self.pos_enc = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead=nhead, dim_feedforward=256)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        head_norm = nn.LayerNorm(d_model)
        head_linear = nn.Linear(d_model, num_classes)
        self.classifier = nn.Sequential(head_norm, head_linear)
        self.regressor = nn.Linear(d_model, 1)

    def forward(self, x):
        """Return ``(class_logits, regression_output)`` for a batch-first input."""
        embedded = self.proj(x)
        # The encoder stack is sequence-first, so swap the batch and sequence axes.
        seq_first = embedded.permute(1, 0, 2)
        encoded = self.encoder(self.pos_enc(seq_first))
        last_step = encoded[-1]
        class_logits = self.classifier(last_step)
        reg_out = self.regressor(last_step).squeeze(-1)
        return class_logits, reg_out

    def compute_loss(self, logits, y_hard, y_soft, reg, y_reg):
        """Combine the soft-label KL term with the weighted regression MSE.

        ``y_hard`` is accepted for interface symmetry but is not used.

        Returns:
            Tuple ``(total_loss_tensor, cls_loss_float, reg_loss_float)``.
        """
        # Classification term: KL divergence between the soft labels and softmax(logits).
        log_probs = F.log_softmax(logits, dim=-1)
        cls_term = F.kl_div(log_probs, y_soft, reduction='batchmean')
        # Regression term: mean squared error.
        reg_term = F.mse_loss(reg, y_reg)
        total = cls_term + self.lambda_reg * reg_term
        return total, cls_term.item(), reg_term.item()

# Build the model with its default hyper-parameters and move it to the selected device.
model = TransformerClassifierSoftLabel().to(device)
opt = Adam(model.parameters(), lr=1e-4)
# NOTE(review): T_max=10 while Step 10 sets EPOCHS = 30. If the scheduler is
# stepped once per epoch, the learning rate rises again after epoch 10 —
# confirm this is intended.
sched = CosineAnnealingLR(opt, T_max=10)

# 指標出力のためのヘルパー関数 (修正版)

# --- 必要なライブラリ ---
import torch
import numpy as np
import os
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score, precision_recall_curve, auc, f1_score
from tqdm.auto import tqdm 
import pandas as pd

# --- Helper Function: threshold-based 3-class prediction labels ---
def get_thresholded_predictions_final(probs_np, threshold_val):
    """Convert class probabilities into 0/1/2 predictions using a threshold.

    Column 1 holds the Min probability and column 2 the Max probability.
    A row is predicted as Min (1) or Max (2) when that probability is at
    least ``threshold_val``. When both qualify, the larger one wins and ties
    go to Min. Rows where neither qualifies are predicted as 0.

    Args:
        probs_np: 2-D array with one row of class probabilities per sample.
        threshold_val: Minimum probability for a Min/Max prediction.

    Returns:
        Integer array with one predicted class per row.
    """
    min_prob = probs_np[:, 1]
    max_prob = probs_np[:, 2]

    min_hit = min_prob >= threshold_val
    max_hit = max_prob >= threshold_val
    # Min is chosen when it qualifies alone, or qualifies and is not beaten by Max.
    min_wins = min_hit & (~max_hit | (min_prob >= max_prob))

    predictions = np.zeros(len(probs_np), dtype=int)
    predictions[max_hit] = 2
    predictions[min_wins] = 1
    return predictions

# --- Helper Function: future price-change and soft-label component statistics ---
def _summarize_values(values, include_std=False):
    """Return mean/median/q1/q3 (and optionally std) of the non-NaN values."""
    summary = {'mean': np.nan, 'median': np.nan, 'q1': np.nan, 'q3': np.nan}
    if include_std:
        summary['std'] = np.nan
    cleaned = [v for v in values if not pd.isna(v)]
    if cleaned:
        summary['mean'] = np.mean(cleaned)
        summary['median'] = np.median(cleaned)
        summary['q1'] = np.percentile(cleaned, 25)
        summary['q3'] = np.percentile(cleaned, 75)
        if include_std:
            summary['std'] = np.std(cleaned)
    return summary


def _soft_label_components(entry):
    """Return the (Min, Max) components of a 3-element soft label, or NaNs."""
    # The soft labels are stored as one numpy array per row, so ndarray must
    # be accepted alongside list and tuple.
    if not isinstance(entry, (list, tuple, np.ndarray)):
        return np.nan, np.nan
    try:
        if len(entry) != 3:
            return np.nan, np.nan
        return float(entry[1]), float(entry[2])
    except (ValueError, TypeError, IndexError):
        return np.nan, np.nan


def calculate_signal_statistics(y_predictions_at_threshold,
                                original_indices_np,
                                source_df,
                                class_index_to_analyze,
                                num_future_steps):
    """Summarise what follows the samples predicted as one class.

    For every sample predicted as ``class_index_to_analyze`` the function
    looks up its row in ``source_df`` and collects:

    1. the percentage change from the current close to the lowest 'low' of
       the next ``num_future_steps`` rows,
    2. the percentage change from the current close to the highest 'high'
       of those rows,
    3. the Min component of the row's soft label,
    4. the Max component of the row's soft label.

    Rows whose index is out of range, whose future window runs past the end
    of ``source_df``, or whose close is NaN or 0 contribute nothing to the
    price-change statistics.

    Args:
        y_predictions_at_threshold: Predicted class per evaluated sample.
        original_indices_np: Row position in ``source_df`` per evaluated sample.
        source_df: Frame with 'low' and 'high' columns; 'close' and
            'soft_label' are used when present.
        class_index_to_analyze: Predicted class to analyse (1 for Min, 2 for Max).
        num_future_steps: Number of following rows forming the future window.

    Returns:
        Tuple ``(stats_low, stats_high, sl_min_stats, sl_max_stats, n_pred)``.
        The first two dicts have keys mean/median/q1/q3, the next two also
        have std, and ``n_pred`` is the number of samples predicted as the
        class. Statistics without any valid value are NaN.
    """
    predicted_positions = np.where(y_predictions_at_threshold == class_index_to_analyze)[0]
    total_predictions_for_class = len(predicted_positions)

    changes_to_low = []
    changes_to_high = []
    sl_min_scores = []
    sl_max_scores = []

    if total_predictions_for_class > 0:
        n_rows = len(source_df)
        # Pull the columns once; per-row DataFrame lookups are far slower.
        low_col = source_df['low']
        high_col = source_df['high']
        close_col = source_df['close'] if 'close' in source_df.columns else None
        soft_col = source_df['soft_label'] if 'soft_label' in source_df.columns else None

        for arr_idx in predicted_positions:
            row = original_indices_np[arr_idx]
            if not (0 <= row < n_rows):
                continue

            if soft_col is not None:
                sl_min, sl_max = _soft_label_components(soft_col.iloc[row])
                sl_min_scores.append(sl_min)
                sl_max_scores.append(sl_max)

            window_end = row + 1 + num_future_steps
            if window_end > n_rows:
                continue
            current_close = close_col.iloc[row] if close_col is not None else np.nan
            if pd.isna(current_close) or current_close == 0:
                continue

            # Series.min()/max() skip NaN and return NaN for an empty window.
            future_low = low_col.iloc[row + 1:window_end].min()
            future_high = high_col.iloc[row + 1:window_end].max()
            if not pd.isna(future_low):
                changes_to_low.append(((future_low - current_close) / current_close) * 100)
            if not pd.isna(future_high):
                changes_to_high.append(((future_high - current_close) / current_close) * 100)

    # Each result is a fresh dict, so callers can mutate one without touching another.
    return (_summarize_values(changes_to_low),
            _summarize_values(changes_to_high),
            _summarize_values(sl_min_scores, include_std=True),
            _summarize_values(sl_max_scores, include_std=True),
            total_predictions_for_class)

# --- Helper Function: format future-change and soft-label statistics as one table row ---
def format_future_stats_table_row(pred_type_str,
                                  stats_low_dict, stats_high_dict,
                                  sl_min_comp_stats_dict, sl_max_comp_stats_dict,
                                  total_preds):
    """Render one pipe-delimited row of the future-returns table.

    Numbers are printed with 4 decimals and NaN values as "N/A". The label
    is left-aligned in 17 characters and every other cell is right-aligned
    in a fixed width.

    Args:
        pred_type_str: Row label.
        stats_low_dict: Dict with mean/median/q1/q3 of the change to the future low.
        stats_high_dict: Dict with mean/median/q1/q3 of the change to the future high.
        sl_min_comp_stats_dict: Dict with mean/median/q1/q3/std of the soft-label Min component.
        sl_max_comp_stats_dict: Dict with mean/median/q1/q3/std of the soft-label Max component.
        total_preds: Number of predictions summarised in the row.

    Returns:
        The formatted row, without a trailing newline.
    """
    def render(val, precision=4):
        if pd.isna(val):
            return "N/A"
        return f"{val:.{precision}f}"

    sl_layout = (('mean', 13), ('median', 14), ('q1', 10), ('q3', 10), ('std', 10))
    column_groups = (
        (stats_low_dict, (('mean', 12), ('median', 14), ('q1', 10), ('q3', 10))),
        (stats_high_dict, (('mean', 13), ('median', 15), ('q1', 11), ('q3', 11))),
        (sl_min_comp_stats_dict, sl_layout),
        (sl_max_comp_stats_dict, sl_layout),
    )

    cells = [f"{pred_type_str:<17}"]
    for stats, layout in column_groups:
        for key, width in layout:
            cells.append(f"{render(stats[key]):>{width}}")
    cells.append(f"{str(total_preds):>7}")
    return "| " + " | ".join(cells) + " |"

# --- Main Evaluation Function ---
def perform_detailed_evaluation(model_instance, dataloader_to_eval, original_df_for_lookup,
                                device_to_use, output_path, file_prefix_str,
                                class_target_names, prob_thresholds, num_future_steps,
                                epoch_info_str="N/A"):
    """Evaluate a model on one DataLoader and write two text reports.

    The model is put in eval mode and run over ``dataloader_to_eval``. Two
    files are then written under ``output_path``:

    * ``<prefix>_metrics_ep<epoch>.txt`` - ROC-AUC / PR-AUC for the Min and
      Max classes, then for each threshold the confusion matrix,
      classification report and binary F1 scores.
    * ``<prefix>_future_returns_ep<epoch>.txt`` - for each threshold a table
      of future price-change and soft-label statistics for Min and Max
      predictions, followed by up/down counts after ``num_future_steps`` rows.

    Args:
        model_instance: Model returning ``(logits, regression_output)``.
        dataloader_to_eval: Loader yielding ``(X, y_hard, y_soft, y_reg, t_index)``
            batches; when None the function reports an error and returns.
        original_df_for_lookup: Frame the ``t_index`` values refer to; needs
            'close', 'low' and 'high', and optionally 'soft_label'.
        device_to_use: Device the inputs are moved to.
        output_path: Directory for the report files (created if missing).
        file_prefix_str: Prefix for the report file names.
        class_target_names: Class names in label order (index 0 is "other").
        prob_thresholds: Probability thresholds to evaluate.
        num_future_steps: Look-ahead, in rows, for the returns analysis.
        epoch_info_str: Label of the epoch the model comes from; used in the
            report headers and, once sanitised, in the file names.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    print(f"\n--- Starting Detailed Evaluation for: {file_prefix_str} (Source Epoch for Model: {epoch_info_str}) ---")
    if dataloader_to_eval is None:
        print(f"ERROR: DataLoader for '{file_prefix_str}' is None.")
        return

    if 'soft_label' not in original_df_for_lookup.columns:
        # Processing continues; the soft-label statistics simply come out as N/A.
        print(f"WARNING: 'soft_label' column not found in 'original_df_for_lookup'. Statistics for soft label components will be N/A for all predictions.")

    model_instance.eval()
    all_true_labels, all_probs, all_original_indices = [], [], []

    with torch.no_grad():
        eval_pbar = tqdm(dataloader_to_eval, desc=f"Evaluating {file_prefix_str}", leave=False)
        for Xb, yh_hard, _, _, t_idx in eval_pbar:
            Xb = Xb.to(device_to_use)
            logits, _ = model_instance(Xb)
            all_probs.append(torch.softmax(logits, dim=-1).cpu().numpy())
            all_true_labels.append(yh_hard.cpu().numpy())
            all_original_indices.append(t_idx.cpu().numpy())

    if not all_true_labels:
        print(f"ERROR: No data from DataLoader for '{file_prefix_str}'.")
        return
    y_true_np = np.concatenate(all_true_labels)
    probs_np = np.concatenate(all_probs)
    original_indices_np = np.concatenate(all_original_indices)
    if not (len(y_true_np) == len(probs_np) == len(original_indices_np)):
        print(f"ERROR: Mismatch in array lengths for {file_prefix_str}.")
        return

    # Path separators must be replaced too: the default label "N/A" would
    # otherwise point into a sub-directory that does not exist.
    ep_info_fn = str(epoch_info_str).replace(" ", "_").replace(":", "-")
    ep_info_fn = ep_info_fn.replace("/", "-").replace("\\", "-")
    os.makedirs(output_path, exist_ok=True)
    metrics_filepath = os.path.join(output_path, f"{file_prefix_str}_metrics_ep{ep_info_fn}.txt")
    returns_filepath = os.path.join(output_path, f"{file_prefix_str}_future_returns_ep{ep_info_fn}.txt")

    # Threshold-independent one-vs-rest AUCs for every class except index 0.
    overall_auc_scores = {}
    for class_idx, class_name in enumerate(class_target_names):
        if class_idx == 0:
            continue
        overall_auc_scores[class_name] = {'roc_auc': "N/A", 'pr_auc': "N/A"}
        # Both scores need positives and negatives to be defined.
        if np.any(y_true_np == class_idx) and np.any(y_true_np != class_idx):
            binary_truth = (y_true_np == class_idx).astype(int)
            try:
                roc = roc_auc_score(binary_truth, probs_np[:, class_idx])
                prec, rec, _ = precision_recall_curve(binary_truth, probs_np[:, class_idx])
                pr = auc(rec, prec)
            except ValueError as e_auc:
                # Keep the "N/A" placeholders but say why instead of hiding the failure.
                print(f"WARNING: Could not compute AUC for class '{class_name}': {e_auc}")
            else:
                overall_auc_scores[class_name]['roc_auc'] = f"{roc:.4f}"
                overall_auc_scores[class_name]['pr_auc'] = f"{pr:.4f}"

    with open(metrics_filepath, "w") as f_metrics, open(returns_filepath, "w") as f_returns:
        f_metrics.write(f"Detailed Metrics for {file_prefix_str} (Model from Epoch: {epoch_info_str})\n")
        f_returns.write(f"Future Returns Analysis for {file_prefix_str} (Model from Epoch: {epoch_info_str})\n")

        f_metrics.write("="*50 + "\nOverall Class-Specific AUCs (Threshold-Independent):\n")
        for cn_auc in [class_target_names[1], class_target_names[2]]:
            f_metrics.write(f"  {cn_auc} - ROC-AUC: {overall_auc_scores[cn_auc]['roc_auc']}, PR-AUC: {overall_auc_scores[cn_auc]['pr_auc']}\n")
        f_metrics.write("="*50 + "\n\n")

        # Column widths match the cell widths of format_future_stats_table_row,
        # and cells are joined with " | " exactly like the data rows, so the
        # header, separator and rows line up.
        header_cols = [
            "Prediction Type".ljust(17),
            "Fut_Low_Mean".rjust(12), "Fut_Low_Median".rjust(14), "Fut_Low_Q1".rjust(10), "Fut_Low_Q3".rjust(10),
            "Fut_High_Mean".rjust(13), "Fut_High_Median".rjust(15), "Fut_High_Q1".rjust(11), "Fut_High_Q3".rjust(11),
            "SLMinC_Mean".rjust(13), "SLMinC_Median".rjust(14), "SLMinC_Q1".rjust(10), "SLMinC_Q3".rjust(10), "SLMinC_Std".rjust(10),
            "SLMaxC_Mean".rjust(13), "SLMaxC_Median".rjust(14), "SLMaxC_Q1".rjust(10), "SLMaxC_Q3".rjust(10), "SLMaxC_Std".rjust(10),
            "N_Pred".rjust(7)
        ]
        returns_table_header = "| " + " | ".join(header_cols) + " |\n"
        returns_table_separator = "|" + "|".join("-" * (len(col) + 2) for col in header_cols) + "|\n"

        for th in prob_thresholds:
            # --- Classification metrics for this threshold ---
            f_metrics.write(f"--- Metrics for Threshold: {th:.2f} ---\n")
            y_pred_at_threshold = get_thresholded_predictions_final(probs_np, th)
            cm = confusion_matrix(y_true_np, y_pred_at_threshold, labels=[0,1,2])
            f_metrics.write("Confusion Matrix:\n" + np.array2string(cm) + "\n\n")
            try:
                cr = classification_report(y_true_np, y_pred_at_threshold, target_names=class_target_names, labels=[0,1,2], zero_division=0, digits=4)
                f_metrics.write("Classification Report:\n" + cr + "\n\n")
            except Exception as e_cr:
                f_metrics.write(f"Could not generate CR for threshold {th:.2f}: {e_cr}\n\n")
            # ">=" matches the rule used by get_thresholded_predictions_final.
            f1_min = f1_score((y_true_np == 1).astype(int), (probs_np[:, 1] >= th).astype(int), zero_division=0)
            f1_max = f1_score((y_true_np == 2).astype(int), (probs_np[:, 2] >= th).astype(int), zero_division=0)
            f_metrics.write(f"Binary F1-score (Min vs Rest) @{th:.2f}: {f1_min:.4f}\n")
            f_metrics.write(f"Binary F1-score (Max vs Rest) @{th:.2f}: {f1_max:.4f}\n")
            f_metrics.write("-" * 40 + "\n\n")

            # --- Future-returns table for this threshold ---
            f_returns.write(f"--- Returns Analysis for Threshold: {th:.2f} ---\n")
            f_returns.write(returns_table_header)
            f_returns.write(returns_table_separator)

            # One table row each for the Min (1) and Max (2) predictions.
            for class_idx_row in (1, 2):
                stats_low, stats_high, sl_min_stats, sl_max_stats, total_preds = calculate_signal_statistics(
                    y_pred_at_threshold, original_indices_np, original_df_for_lookup,
                    class_idx_row, num_future_steps
                )
                f_returns.write(format_future_stats_table_row(
                    f"{class_target_names[class_idx_row]} Predictions", stats_low, stats_high,
                    sl_min_stats, sl_max_stats,
                    total_preds) + "\n")

            # Up/down counts and ratios, written after the table.
            f_returns.write(f"\nPrice Change Direction after {num_future_steps} steps (for signals at threshold {th:.2f}):\n")
            for class_idx_ud, class_name_ud in [(1, class_target_names[1]), (2, class_target_names[2])]:
                up_count = 0; down_count = 0; neutral_count = 0; valid_comparison_count = 0
                predicted_indices_for_class = np.where(y_pred_at_threshold == class_idx_ud)[0]

                for arr_idx_ud in predicted_indices_for_class:
                    original_df_idx_ud = original_indices_np[arr_idx_ud]
                    # Skip signals whose current or future row falls outside the frame.
                    if not (0 <= original_df_idx_ud < len(original_df_for_lookup) and
                            original_df_idx_ud + num_future_steps < len(original_df_for_lookup)):
                        continue
                    current_close_ud = original_df_for_lookup['close'].iloc[original_df_idx_ud]
                    future_close_ud = original_df_for_lookup['close'].iloc[original_df_idx_ud + num_future_steps]
                    if pd.isna(current_close_ud) or pd.isna(future_close_ud):
                        continue
                    valid_comparison_count += 1
                    if future_close_ud > current_close_ud:
                        up_count += 1
                    elif future_close_ud < current_close_ud:
                        down_count += 1
                    else:
                        neutral_count += 1

                up_down_total_for_ratio = up_count + down_count
                up_pct_of_directional = f"{(up_count / up_down_total_for_ratio * 100):.2f}%" if up_down_total_for_ratio > 0 else "N/A"
                direct_up_down_ratio = "N/A"
                if down_count > 0:
                    direct_up_down_ratio = f"{(up_count / down_count):.2f}"
                elif up_count > 0:
                    direct_up_down_ratio = "Inf (all up)"  # No downs, only ups

                f_returns.write(f"  For {class_name_ud} Pred (N_pred={len(predicted_indices_for_class)}, N_valid_comp={valid_comparison_count}): "
                                f"Up: {up_count}, Down: {down_count}, Neutral: {neutral_count} | "
                                f"Up Pct (of Up+Down): {up_pct_of_directional}, Up/Down Ratio: {direct_up_down_ratio}\n")
            f_returns.write("-" * 40 + "\n\n")  # End of this threshold's block

    print(f"INFO: Detailed metrics for '{file_prefix_str}' saved to: {metrics_filepath}")
    print(f"INFO: Future returns analysis for '{file_prefix_str}' saved to: {returns_filepath}")
    print(f"--- Finished Detailed Evaluation for: {file_prefix_str} ---")
    
# Step 10: Training loop (tqdm + live plot + detailed evaluation + best-model checkpointing)

# --- Prerequisites for running this Step 10 block ---
# The following objects/variables must have been defined by the earlier steps (1-9):
#   model, opt, sched, dl_tr, dl_val, df_valid, device, output_dir
#   the TransformerClassifierSoftLabel class definition
#   helper functions: get_thresholded_predictions_final, log_future_price_change_stats_revised
# Required libraries: torch, numpy as np, os, sklearn.metrics (various functions),
#                     IPython.display (display, clear_output), matplotlib.pyplot as plt, tqdm.auto
# -----------------------------------------------------------------

print(f"INFO: Step 10 - Initializing variables and starting training loop.")

# --- Parameters and state variables used by Step 10 ---
EPOCHS = 30  # total number of epochs (example value; adjust to the actual run)
thresholds = [0.25, 0.5, 0.75]  # probability thresholds used for evaluation
target_names = ['Other', 'Min', 'Max']  # class label names, indexed by class id
FUTURE_STEPS = 6  # how many steps ahead the future price change is inspected

# Per-epoch loss history (one entry appended per epoch for every key)
history = {'train_total':[], 'train_cls':[], 'train_reg':[],
           'val_total':[],   'val_cls':[],   'val_reg':[]}

# Best-model bookkeeping
min_val_loss = float('inf')
# output_dir is expected to be defined in Step 7. The checkpoint and the per-epoch
# report files are written into it during training, so make sure it exists up
# front instead of failing with FileNotFoundError in the first epoch.
os.makedirs(output_dir, exist_ok=True)
best_model_path = os.path.join(output_dir, "best_transformer_model.pth")
best_epoch = -1  # sentinel: no checkpoint has been written yet

# Prepare the figure/axes for the live loss plot. The existence check lets a
# re-run of this cell reuse a figure that is still open instead of leaking a new one.
# Either close it at the end of Step 10 (plt.close(fig_loss)) or manage it in an
# outer scope; here it is initialised locally so Step 10 is self-contained.
if 'fig_loss' not in locals() or fig_loss is None or not plt.fignum_exists(fig_loss.number):
    fig_loss, ax_loss = plt.subplots(1, 3, figsize=(18, 5))
    print("INFO: Step 10 - Loss plot figure initialized.")
# -------------------------------------------------------------

print(f"INFO: Starting training and evaluation loop for {EPOCHS} epochs.")
for ep in range(1, EPOCHS + 1):
    # --- Train Phase ---
    model.train()
    train_run_metrics = {'total_loss': 0.0, 'cls_loss': 0.0, 'reg_loss': 0.0}
    num_train_batches = 0
    train_pbar = tqdm(dl_tr, desc=f"Epoch {ep}/{EPOCHS} [Train]", leave=False)
    # The fifth batch element (the time indices) is not used during training.
    for Xb, yh_hard, ys_soft, y_reg_target, _ in train_pbar:
        Xb, yh_hard, ys_soft, y_reg_target = Xb.to(device), yh_hard.to(device), ys_soft.to(device), y_reg_target.to(device)

        opt.zero_grad()
        logits, reg_pred = model(Xb)
        loss, lcls, lreg = model.compute_loss(logits, yh_hard, ys_soft, reg_pred, y_reg_target)
        loss.backward()
        opt.step()

        train_run_metrics['total_loss'] += loss.item()
        # float() is a no-op for plain numbers and detaches one-element tensors,
        # so the running sums never hold on to an autograd graph even if
        # compute_loss does not call .item() on its component losses.
        train_run_metrics['cls_loss'] += float(lcls)
        train_run_metrics['reg_loss'] += float(lreg)
        num_train_batches += 1
        train_pbar.set_postfix({k_train: v_train / num_train_batches for k_train, v_train in train_run_metrics.items()})

    # Step the scheduler once per epoch, if one with a step() method was provided.
    if hasattr(sched, 'step'):
        sched.step()

    history['train_total'].append(train_run_metrics['total_loss'] / num_train_batches if num_train_batches > 0 else 0)
    history['train_cls'].append(train_run_metrics['cls_loss'] / num_train_batches if num_train_batches > 0 else 0)
    history['train_reg'].append(train_run_metrics['reg_loss'] / num_train_batches if num_train_batches > 0 else 0)

    # --- Validation Phase ---
    # Number of validation batches seen this epoch; stays 0 when there is no
    # validation loader or it yields nothing, and gates the plotting below.
    num_val_batches = 0
    # Explicit None check: truth-testing a DataLoader calls len() on it, which
    # raises TypeError for loaders that have no length.
    if dl_val is not None:
        model.eval()
        val_run_metrics = {'total_loss': 0.0, 'cls_loss': 0.0, 'reg_loss': 0.0}
        y_true_val_hard_all, probs_val_all, t_indices_val_all = [], [], []

        val_pbar = tqdm(dl_val, desc=f"Epoch {ep}/{EPOCHS} [Val]", leave=False)
        with torch.no_grad():
            for Xb_val, yh_hard_val, ys_soft_val, y_reg_target_val, t_idx_val in val_pbar:
                Xb_val, yh_hard_val, ys_soft_val, y_reg_target_val = Xb_val.to(device), yh_hard_val.to(device), ys_soft_val.to(device), y_reg_target_val.to(device)

                logits_val, reg_pred_val = model(Xb_val)
                loss_val, lcls_val, lreg_val = model.compute_loss(logits_val, yh_hard_val, ys_soft_val, reg_pred_val, y_reg_target_val)

                val_run_metrics['total_loss'] += loss_val.item()
                val_run_metrics['cls_loss'] += float(lcls_val)
                val_run_metrics['reg_loss'] += float(lreg_val)
                num_val_batches += 1
                val_pbar.set_postfix({k_val: v_val / num_val_batches for k_val, v_val in val_run_metrics.items()})

                # Keep per-batch class probabilities, hard labels and time indices
                # on the CPU for the epoch-level metrics below.
                probs_val_all.append(torch.softmax(logits_val, dim=-1).cpu().numpy())
                y_true_val_hard_all.append(yh_hard_val.cpu().numpy())
                t_indices_val_all.append(t_idx_val.cpu().numpy())

        # An empty validation pass is recorded as inf so it can never be "best".
        current_epoch_val_total_loss = val_run_metrics['total_loss'] / num_val_batches if num_val_batches > 0 else float('inf')
        history['val_total'].append(current_epoch_val_total_loss)
        history['val_cls'].append(val_run_metrics['cls_loss'] / num_val_batches if num_val_batches > 0 else 0)
        history['val_reg'].append(val_run_metrics['reg_loss'] / num_val_batches if num_val_batches > 0 else 0)

        # Checkpoint whenever the mean validation total loss improves.
        # Nothing is printed here because it would garble the tqdm bars; the
        # best epoch is reported once after the loop.
        if current_epoch_val_total_loss < min_val_loss:
            min_val_loss = current_epoch_val_total_loss
            torch.save(model.state_dict(), best_model_path)
            best_epoch = ep

        # --- Validation Metrics & Future Returns Analysis ---
        if num_val_batches > 0:
            y_true_np = np.concatenate(y_true_val_hard_all)
            probs_np = np.concatenate(probs_val_all)
            val_original_indices_np = np.concatenate(t_indices_val_all)

            # Zero-padded epoch number keeps the report files sortable by name.
            metrics_filepath = os.path.join(output_dir, f"validation_metrics_ep{ep:03d}.txt")
            returns_filepath = os.path.join(output_dir, f"validation_future_returns_ep{ep:03d}.txt")

            with open(metrics_filepath, "w") as f_metrics, open(returns_filepath, "w") as f_returns:
                f_metrics.write(f"Epoch {ep} - Validation Set Metrics\n")
                f_returns.write(f"Epoch {ep} - Validation Set Future Returns Analysis\n")

                f_metrics.write("="*50 + "\nOverall Class-Specific AUCs (Threshold-Independent):\n")
                for class_idx, class_name in enumerate(target_names):
                    if class_idx == 0: continue # Skip 'Other' class for this specific AUC reporting
                    # One-vs-rest AUCs need both positives and negatives to be present.
                    if np.any(y_true_np == class_idx) and np.any(y_true_np != class_idx):
                        try:
                            roc_auc = roc_auc_score((y_true_np == class_idx).astype(int), probs_np[:, class_idx])
                            precision, recall, _ = precision_recall_curve((y_true_np == class_idx).astype(int), probs_np[:, class_idx])
                            pr_auc = auc(recall, precision)
                            f_metrics.write(f"  {class_name} - ROC-AUC: {roc_auc:.4f}, PR-AUC: {pr_auc:.4f}\n")
                        except ValueError as e_auc:
                            # Record the failure in the report instead of aborting training.
                            f_metrics.write(f"  {class_name} - ROC-AUC: Error ({e_auc}), PR-AUC: Error\n")
                    else:
                        f_metrics.write(f"  {class_name} - ROC-AUC: N/A, PR-AUC: N/A (Insufficient class diversity or no samples for this class)\n")
                f_metrics.write("="*50 + "\n\n")

                for th in thresholds:
                    f_metrics.write(f"--- Metrics for Threshold: {th:.2f} ---\n")
                    f_returns.write(f"--- Returns Analysis for Threshold: {th:.2f} ---\n")

                    # get_thresholded_predictions_final must be defined earlier in the file.
                    y_pred_at_threshold = get_thresholded_predictions_final(probs_np, th)

                    cm = confusion_matrix(y_true_np, y_pred_at_threshold, labels=[0,1,2])
                    f_metrics.write("Confusion Matrix:\n" + np.array2string(cm) + "\n\n")
                    try:
                        cr = classification_report(y_true_np, y_pred_at_threshold, target_names=target_names, labels=[0,1,2], zero_division=0, digits=4)
                        f_metrics.write("Classification Report:\n" + cr + "\n\n")
                    except Exception as e_cr:
                        # Best-effort reporting: note the failure and keep training.
                        f_metrics.write(f"Could not generate Classification Report for threshold {th:.2f}: {e_cr}\n\n")

                    # One-vs-rest F1 computed directly on the class probability
                    # column (strictly greater than the threshold).
                    f1_min = f1_score((y_true_np == 1).astype(int), (probs_np[:, 1] > th).astype(int), zero_division=0)
                    f1_max = f1_score((y_true_np == 2).astype(int), (probs_np[:, 2] > th).astype(int), zero_division=0)
                    f_metrics.write(f"Binary F1-score (Min vs Rest) @{th:.2f}: {f1_min:.4f}\n")
                    f_metrics.write(f"Binary F1-score (Max vs Rest) @{th:.2f}: {f1_max:.4f}\n")
                    f_metrics.write("-" * 40 + "\n\n")

                    # log_future_price_change_stats_revised must be defined earlier in
                    # the file; df_valid is read from the enclosing (module) scope.
                    log_future_price_change_stats_revised(f_returns, y_pred_at_threshold, val_original_indices_np, df_valid,
                                                          1, target_names[1], FUTURE_STEPS, th)
                    log_future_price_change_stats_revised(f_returns, y_pred_at_threshold, val_original_indices_np, df_valid,
                                                          2, target_names[2], FUTURE_STEPS, th)
    else:
        # No validation loader: keep the history lists aligned with the epochs.
        history['val_total'].append(float('inf'))  # treated as "infinitely bad"
        history['val_cls'].append(0)
        history['val_reg'].append(0)

    # --- Plot Losses ---
    clear_output(wait=True)  # refresh the notebook output so the plot updates in place
    for i, key_suffix in enumerate(['total', 'cls', 'reg']):
        ax = ax_loss[i]
        ax.clear()
        ax.plot(history[f'train_{key_suffix}'], label=f'train_{key_suffix}')
        # Draw the validation curve only when validation actually produced
        # batches this epoch and the recorded values are not all inf.
        if num_val_batches > 0 and not np.all(np.isinf(history[f'val_{key_suffix}'])):
             ax.plot(history[f'val_{key_suffix}'], label=f'val_{key_suffix}')
        ax.set_title(f"{key_suffix.capitalize()} Loss (Epoch {ep})")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Average Loss")
        ax.legend()
    fig_loss.tight_layout()
    display(fig_loss)  # show the updated figure in the notebook

# --- Training-loop summary ---
print("\nINFO: Step 10 - Training Loop Finished ---")
if best_epoch == -1:
    # best_epoch still holds its -1 sentinel: no checkpoint was written.
    print("No best model was saved (either validation loss did not improve or no validation data was provided).")
else:
    print(
        f"Best model found at epoch {best_epoch} with validation total loss: {min_val_loss:.4f}",
        f"Best model weights saved to: {best_model_path}",
        sep="\n",
    )

# Uncomment the line below to close the plot window once training has finished.
# plt.close(fig_loss)

# Step 11: Test evaluation (detailed) - version kept independent of Step 10

# --- Required libraries (assumed to be imported at the top of the script or next to the helper definitions) ---
# import torch
# import numpy as np
# import os
# from tqdm.auto import tqdm # used inside perform_detailed_evaluation
# (likewise the sklearn.metrics functions needed by perform_detailed_evaluation and its helpers)

# --- Prerequisites for running this Step 11 block ---
# 1. Model class definition (must already be defined at global scope):
#    TransformerClassifierSoftLabel
# 2. Helper functions (must already be defined at global scope):
#    get_thresholded_predictions_final(probs_np, threshold_val)
#    calculate_future_price_stats(y_predictions_at_threshold, ...) # mind the latest signature
#    format_future_stats_table_row(pred_type_str, ...)          # mind the latest signature
#    perform_detailed_evaluation(model_instance, ...)
#      (their definitions are the revised "helper functions for metrics output" provided earlier)
# 3. Required objects/variables (prepared by earlier steps and available at global scope):
#   - dl_test: torch.utils.data.DataLoader - test data loader (created in Step 8)
#   - df_valid: pd.DataFrame - source DataFrame the test-sample indices map into (loaded/processed in Step 5)
#                             (NOTE: if the test data does not come from df_valid, supply the appropriate DataFrame)
#   - device: str - compute device (defined in Step 3)
#   - output_dir: str - directory where result files are saved (defined in Step 7)
#   - SEQ_LEN: int - model input sequence length (defined in Step 7)
#
#   Note: the evaluation parameters below (TARGET_NAMES_S11, THRESHOLDS_S11, FUTURE_STEPS_S11)
#         are defined inside this Step 11 block. If the settings should match the Step 10
#         validation, it is strongly recommended to define them once in a shared
#         configuration section higher up and have both Step 10 and Step 11 read them.
#         That prevents the two steps from drifting apart and eases maintenance.
# ------------------------------------------------------------------------------------

print("\nINFO: Step 11 - Starting Test Set Evaluation.")
print("INFO: This step defines its own evaluation parameters for independence.")
print("      It relies on a saved model specified by 'BEST_MODEL_PATH_S11'.")

# --- Evaluation parameters and settings specific to Step 11 ---
TARGET_NAMES_S11 = ['Other', 'Min', 'Max']    # class label names
THRESHOLDS_S11 = [0.25, 0.5, 0.75]        # probability thresholds used for evaluation
FUTURE_STEPS_S11 = 6                     # how many steps ahead the future price change is inspected

# Step 11 does not depend on Step 10's 'best_epoch', so the epoch information
# handed to perform_detailed_evaluation is a fixed label.
EPOCH_INFO_FOR_FILE_S11 = "best_model_file" # e.g. "best_model_from_file" or "loaded_best_model"
# ----------------------------------------------------

# --- Existence check for the essential globals ---
# This runs BEFORE output_dir is first used below, so a missing prerequisite is
# reported by name here rather than surfacing as a bare NameError.
essential_globals_s11 = {
    'dl_test': "Test DataLoader", 'df_valid': "Source DataFrame for test context", 
    'device': "Device string", 'output_dir': "Output directory path",
    'SEQ_LEN': "Sequence length parameter", 
    'TransformerClassifierSoftLabel': "Model class definition",
    'perform_detailed_evaluation': "Main evaluation helper function",
    'get_thresholded_predictions_final': "Prediction helper function",
    'calculate_future_price_stats': "Statistics helper function", # mind the latest function name
    'format_future_stats_table_row': "Table formatting helper function" # mind the latest function name
}
all_prerequisites_available_s11 = True
for var, name in essential_globals_s11.items():
    if var not in globals() and var not in locals(): 
        # Report every missing prerequisite before failing, not just the first one.
        print(f"ERROR: Prerequisite '{name}' (variable/function: {var}) for Step 11 is not defined.")
        all_prerequisites_available_s11 = False

if not all_prerequisites_available_s11:
    raise NameError("Missing one or more crucial prerequisites for Step 11. Evaluation cannot proceed.")

# The best-model path is built from output_dir and a fixed file name; a model
# trained and saved by Step 10 (or equivalent) is expected to exist there.
BEST_MODEL_PATH_S11 = os.path.join(output_dir, "best_transformer_model.pth")

if not os.path.exists(output_dir): 
    # exist_ok=True tolerates the directory appearing between the check and the call.
    os.makedirs(output_dir, exist_ok=True)
    print(f"INFO: Created output directory for Step 11 results: {output_dir}")
# ---------------------------------------------

# Build a fresh model instance for testing.
# NOTE(review): constructed with default arguments; this only matches the saved
# weights if training also used the defaults - confirm against the training step.
model_for_testing_s11 = TransformerClassifierSoftLabel().to(device)
evaluation_performed_s11 = False

if dl_test is None:
    print("INFO: Test data loader (dl_test) is None. Skipping test evaluation.")
else:
    if not os.path.exists(BEST_MODEL_PATH_S11):
        print(f"ERROR: Specified model path '{BEST_MODEL_PATH_S11}' does not exist.")
        print("INFO: Cannot perform evaluation. Please ensure a trained model is saved at this path,")
        print("      or update 'BEST_MODEL_PATH_S11' (defined at the start of Step 11) to the correct file path.")
    else:
        print(f"INFO: Found model file at '{BEST_MODEL_PATH_S11}'. Attempting to load weights...")
        try:
            model_for_testing_s11.load_state_dict(torch.load(BEST_MODEL_PATH_S11, map_location=device))
            print(f"INFO: Successfully loaded model weights from '{BEST_MODEL_PATH_S11}'.")

            # perform_detailed_evaluation expects original_df_for_lookup (df_valid
            # here) to contain the soft-label score columns it reads; pass the
            # column names explicitly (see the commented-out keywords) if they
            # differ from its defaults.
            perform_detailed_evaluation(
                model_instance=model_for_testing_s11,
                dataloader_to_eval=dl_test,
                original_df_for_lookup=df_valid,  # IMPORTANT: use the proper DataFrame if the test data is not derived from df_valid
                device_to_use=device,
                output_path=output_dir,
                file_prefix_str="test_independent_eval",  # prefix of the output file names
                class_target_names=TARGET_NAMES_S11,      # Step 11 local setting
                prob_thresholds=THRESHOLDS_S11,           # Step 11 local setting
                num_future_steps=FUTURE_STEPS_S11,        # Step 11 local setting
                epoch_info_str=EPOCH_INFO_FOR_FILE_S11    # fixed label defined in Step 11
                # s_min_raw_score_col_name='s_min_raw_score', # specify if needed
                # s_max_raw_score_col_name='s_max_raw_score'  # specify if needed
            )
        except Exception as e:
            # Top-level boundary of the script: log the failure with its
            # traceback and carry on to the final status messages.
            print(f"ERROR: Failed to load or evaluate model from '{BEST_MODEL_PATH_S11}'. Error: {e}")
            import traceback
            traceback.print_exc()
        else:
            # Reached only when both loading and evaluation completed.
            evaluation_performed_s11 = True

    if not evaluation_performed_s11:
        print("INFO: Step 11 - No evaluation was successfully performed with a trained model.")

print("INFO: Step 11 - Test Evaluation Phase Finished.")