# 3. การเทรน Agent (Agent Training)
## ขั้นตอนการเทรน RL Agent สำหรับ Crypto Trading

### เป้าหมาย:
- โหลดข้อมูลที่ประมวลผลแล้ว และ Config ของ Environment/Agent ที่สร้างไว้
- เทรน Agent ด้วยข้อมูล Training
- Validate ผลการเทรนด้วยข้อมูล Validation
- บันทึก Model ที่เทรนแล้ว
- วิเคราะห์ Learning Progress

## Cell 1: Import Libraries และโหลดข้อมูลก่อนหน้า

In [None]:
import sys
import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import torch
import time
from datetime import datetime

# FinRL imports
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.agents.stablebaselines3.models import DRLAgent
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
from stable_baselines3.common.monitor import Monitor

# Import config
from config import *

# Setup directories
PROCESSED_DIR = "processed_data"
MODEL_DIR = "models"
AGENT_DIR = "agents"
LOGS_DIR = "logs"
TENSORBOARD_DIR = "tensorboard_logs"

# Create the output directories this notebook writes to. PROCESSED_DIR and
# AGENT_DIR are only read here, so they are not created.
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
for dir_name in [MODEL_DIR, LOGS_DIR, TENSORBOARD_DIR]:
    os.makedirs(dir_name, exist_ok=True)

print("📁 Setup directories completed")
print("🚀 Starting Agent Training Process")

## Cell 2: โหลดข้อมูล สร้าง Environment ใหม่ และโหลด Config ของ Agent

In [None]:
# Functions that load the processed data and the saved environment/agent setup
def _load_pickle(path):
    """Open ``path`` in binary mode and return the unpickled object.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use it on
    artifacts produced by this project, never on files from untrusted sources.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def load_training_setup():
    """Load the processed dataset and the artifacts saved by the agent-setup step.

    The dataset is read from ``PROCESSED_DIR/processed_crypto_data.pkl``. If
    that pickle cannot be read or unpickled, the function falls back to
    ``PROCESSED_DIR/processed_crypto_data.csv`` and parses its ``timestamp``
    column with ``pd.to_datetime``.

    Returns:
        tuple: ``(df, env_config, agent_configs, agent_info)`` -- the dataset
        plus the objects unpickled from ``AGENT_DIR/environment_config.pkl``,
        ``AGENT_DIR/agent_configs.pkl`` and ``AGENT_DIR/agent_info.pkl``.

    Errors raised by the CSV fallback or by loading the three config pickles
    are not caught and propagate to the caller.
    """
    print("📂 Loading training setup...")
    pickle_file = os.path.join(PROCESSED_DIR, "processed_crypto_data.pkl")
    try:
        df = _load_pickle(pickle_file)
    except Exception as err:
        # Deliberate best-effort fallback to the CSV copy. Unlike the former bare
        # `except:`, this does not swallow KeyboardInterrupt/SystemExit, and
        # the reason the pickle was unusable is reported.
        print(f"⚠️ Could not load {pickle_file} ({type(err).__name__}: {err}); falling back to CSV")
        csv_file = os.path.join(PROCESSED_DIR, "processed_crypto_data.csv")
        df = pd.read_csv(csv_file)
        df['timestamp'] = pd.to_datetime(df['timestamp'])
        print(f"✅ Loaded processed data from {csv_file}")
    else:
        print(f"✅ Loaded processed data from {pickle_file}")

    env_config = _load_pickle(os.path.join(AGENT_DIR, "environment_config.pkl"))
    print("✅ Loaded environment config")
    agent_configs = _load_pickle(os.path.join(AGENT_DIR, "agent_configs.pkl"))
    print("✅ Loaded agent configs")
    agent_info = _load_pickle(os.path.join(AGENT_DIR, "agent_info.pkl"))
    print("✅ Loaded agent info")
    return df, env_config, agent_configs, agent_info

def recreate_environments(df, env_config):
    """Split ``df`` chronologically into train/val/test and build one env per split.

    The split is taken over the distinct timestamps: the first 70% of them go
    to train, the next 15% to validation and the remainder to test. Every
    timestamp, together with all of its tickers, therefore lands in exactly
    one split. Each split is sorted by ``(timestamp, tic)`` and given an index
    in which all rows that share a timestamp share one integer step number.
    This is the layout FinRL's ``StockTradingEnv`` reads with
    ``df.loc[day, :]``. With a single ticker it is identical to a 0..n-1
    index. A ``date`` column (calendar date of ``timestamp``) is added to
    every split. The caller's ``df`` is not modified.

    Args:
        df: Processed data containing ``timestamp`` and ``tic`` columns.
        env_config: Dict whose ``'env_kwargs'`` entry is forwarded to every
            ``StockTradingEnv``.

    Returns:
        tuple: ``(train_env, val_env, test_env, train_df, val_df, test_df)``.
        ``train_env`` and ``val_env`` are wrapped in SB3 ``Monitor`` (logging
        under ``LOGS_DIR``); ``test_env`` is returned unwrapped.
    """
    print("🏛️ Recreating environments...")
    data = df.copy()
    data['timestamp'] = pd.to_datetime(data['timestamp'])
    # Sort by the full timestamp, not just the calendar date, so intraday
    # candles stay in chronological order within each split.
    data = data.sort_values(['timestamp', 'tic'], ignore_index=True)

    # Step number per row: 0 for the earliest timestamp, 1 for the next, ...
    step = data['timestamp'].factorize()[0]
    n_steps = data['timestamp'].nunique()
    train_end = int(n_steps * 0.7)
    val_end = train_end + int(n_steps * 0.15)
    masks = (
        step < train_end,
        (step >= train_end) & (step < val_end),
        step >= val_end,
    )

    splits = []
    for mask in masks:
        split = data[mask].reset_index(drop=True)
        split['date'] = split['timestamp'].dt.date
        # Rows sharing a timestamp share an index value (per-step indexing).
        split.index = split['timestamp'].factorize()[0]
        splits.append(split)
    train_df, val_df, test_df = splits

    env_kwargs = env_config['env_kwargs']
    train_env = StockTradingEnv(df=train_df, **env_kwargs)
    val_env = StockTradingEnv(df=val_df, **env_kwargs)
    test_env = StockTradingEnv(df=test_df, **env_kwargs)
    train_env = Monitor(train_env, os.path.join(LOGS_DIR, "train_monitor"))
    val_env = Monitor(val_env, os.path.join(LOGS_DIR, "val_monitor"))
    print("✅ Environments recreated and wrapped with Monitor")
    return train_env, val_env, test_env, train_df, val_df, test_df

# Load the saved artifacts and rebuild the train/val/test environments.
df, env_config, agent_configs, agent_info = load_training_setup()
(
    train_env, val_env, test_env,
    train_df, val_df, test_df,
) = recreate_environments(df, env_config)

print("\n📊 Training setup completed:")
for _label, _split in (("Train", train_df), ("Val", val_df), ("Test", test_df)):
    print(f"  {_label} data: {len(_split)} rows")
print(f"  Model: {agent_info['model_name']}")
print(f"  Device: {agent_info['device']}")

## Cell 3: สร้าง Callbacks และ Training Setup

In [None]:
def create_training_callbacks(val_env, model_name, eval_freq=5000,
                              n_eval_episodes=5, reward_threshold=1000):
    """Build the SB3 callbacks used while training ``model_name``.

    An ``EvalCallback`` evaluates the policy on ``val_env`` every ``eval_freq``
    callback calls, using deterministic actions over ``n_eval_episodes``
    episodes. It saves the best model under ``MODEL_DIR/best_<model>_model``
    and writes evaluation logs under ``LOGS_DIR/eval_<model>``. A
    ``StopTrainingOnRewardThreshold`` is attached as its
    ``callback_on_new_best``, so training stops once the best mean evaluation
    reward exceeds ``reward_threshold``.

    Args:
        val_env: Environment used for periodic evaluation.
        model_name: Algorithm name; lower-cased for the output paths.
        eval_freq: Evaluate every this many callback calls (default 5000).
        n_eval_episodes: Episodes per evaluation (default 5).
        reward_threshold: Best-mean-reward level that stops training (default 1000).

    Returns:
        list: ``[eval_callback]``, suitable for ``model.learn(callback=...)``.
    """
    print("🔧 Creating training callbacks...")
    name = model_name.lower()
    # StopTrainingOnRewardThreshold reads the best mean reward from its parent
    # EvalCallback and asserts that the parent exists. It must therefore be
    # attached via callback_on_new_best, not placed in the callback list alone
    # (which fails on the first training step).
    reward_threshold_callback = StopTrainingOnRewardThreshold(
        reward_threshold=reward_threshold,
        verbose=1
    )
    eval_callback = EvalCallback(
        val_env,
        callback_on_new_best=reward_threshold_callback,
        best_model_save_path=os.path.join(MODEL_DIR, f"best_{name}_model"),
        log_path=os.path.join(LOGS_DIR, f"eval_{name}"),
        eval_freq=eval_freq,
        n_eval_episodes=n_eval_episodes,
        deterministic=True,
        render=False,
        verbose=1
    )
    callbacks = [eval_callback]
    print("✅ Training callbacks created")
    return callbacks

def setup_device():
    """Choose the torch device for training and print what was selected.

    Returns ``torch.device("cuda")`` when CUDA is available, after printing
    the name and total memory of GPU 0. Otherwise returns
    ``torch.device("cpu")``. Also sets ``CUDA_VISIBLE_DEVICES`` to ``"0"``
    (GPU) or ``"-1"`` (CPU), but only if it is not already set.
    """
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        device = torch.device("cuda")
        print(f"✅ Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    else:
        device = torch.device("cpu")
        print("ℹ️ Using CPU")
    # By now torch has already queried CUDA, so this variable cannot change
    # which GPUs this process sees; it only affects child processes.
    # setdefault keeps a user's existing choice (e.g. "2") instead of
    # silently remapping children to a different physical GPU.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0" if cuda_available else "-1")
    return device

device = setup_device()
# Inject the selected torch device into the hyper-parameters of every known algorithm.
for _algo in ('PPO', 'A2C', 'DDPG', 'SAC'):
    if _algo in agent_configs:
        agent_configs[_algo]['device'] = device

model_name = agent_info['model_name']
callbacks = create_training_callbacks(val_env, model_name)
print(f"\n🔧 Training setup completed for {model_name}")

## Cell 4: สร้างและเทรน Agent

In [None]:
# Create a model for the configured algorithm, train it and save it
def create_and_train_agent(train_env, model_name, agent_configs, callbacks):
    """Build a FinRL/SB3 model, train it on ``train_env`` and save it.

    Args:
        train_env: Training environment handed to ``DRLAgent``.
        model_name: Algorithm key such as ``'PPO'``. Its lower-cased form is
            passed to ``DRLAgent.get_model`` and used in the output file names.
        agent_configs: Dict holding ``agent_configs[model_name]`` (model kwargs,
            copied before use) and ``agent_configs['TRAINING']``, which must
            provide ``'total_timesteps'`` and ``'tb_log_name'``.
        callbacks: SB3 callback or list of callbacks forwarded to ``model.learn``.

    Returns:
        tuple: ``(trained_model, training_info)`` on success. The model is saved
        to ``MODEL_DIR/trained_<model>_model`` and ``training_info`` is pickled
        to ``MODEL_DIR/training_info_<model>.pkl``. Any exception is printed
        with its traceback and ``(None, None)`` is returned.
    """
    import traceback  # local import: only needed to report failures

    print(f"🤖 Creating and training {model_name} agent...")
    start_time = time.time()
    # Captured before training; it was previously taken after training had
    # finished, so 'training_start' actually held the end time.
    training_start = datetime.now().isoformat()
    try:
        agent = DRLAgent(env=train_env)
        model_params = agent_configs[model_name].copy()
        training_config = agent_configs['TRAINING']
        print(f"🧠 Model parameters:")
        for key, value in model_params.items():
            print(f"  {key}: {value}")
        print(f"\n⏳ Training configuration:")
        for key, value in training_config.items():
            print(f"  {key}: {value}")
        # NOTE(review): no tensorboard_log is passed to get_model here, so unless
        # model_params supplies one, tb_log_name below presumably has no effect
        # and nothing is written to TENSORBOARD_DIR -- confirm this is intended.
        model = agent.get_model(model_name.lower(), model_kwargs=model_params)
        print(f"✅ {model_name} model created successfully")
        print(f"\n🏃 Starting training...")
        print(f"📊 Training timesteps: {training_config['total_timesteps']:,}")
        # Call SB3's learn() directly; it accepts a callback or a list of them.
        # FinRL's DRLAgent.train_model has no `callback` keyword (recent
        # versions name it `callbacks`, older ones take none), so passing
        # callback= to it raised a TypeError before training started.
        trained_model = model.learn(
            total_timesteps=training_config['total_timesteps'],
            tb_log_name=training_config['tb_log_name'],
            callback=callbacks
        )
        training_time = time.time() - start_time
        print(f"\n✅ Training completed successfully!")
        print(f"⏱️ Training time: {training_time/60:.2f} minutes")
        model_path = os.path.join(MODEL_DIR, f"trained_{model_name.lower()}_model")
        trained_model.save(model_path)
        print(f"💾 Model saved to {model_path}")
        training_info = {
            'model_name': model_name,
            'training_time_minutes': training_time/60,
            'total_timesteps': training_config['total_timesteps'],
            'training_start': training_start,
            'model_params': model_params,
            'training_config': training_config,
            'model_path': model_path
        }
        training_info_file = os.path.join(MODEL_DIR, f"training_info_{model_name.lower()}.pkl")
        with open(training_info_file, 'wb') as f:
            pickle.dump(training_info, f)
        print(f"💾 Training info saved to {training_info_file}")
        return trained_model, training_info
    except Exception as e:
        # Best-effort contract kept ((None, None) on failure), but the full
        # traceback is now printed instead of only the message.
        print(f"❌ Error during training: {str(e)}")
        traceback.print_exc()
        return None, None

trained_model, training_info = create_and_train_agent(train_env, model_name, agent_configs, callbacks)
if training_info is None:
    # create_and_train_agent returns (None, None) on failure. Stop here with a
    # clear message instead of a TypeError from subscripting None below.
    raise RuntimeError(f"{model_name} agent training failed; see the error printed above")
print(f"\n🎉 {model_name} agent training completed!")
print(f"📊 Training summary:")
print(f"  Model: {training_info['model_name']}")
print(f"  Training time: {training_info['training_time_minutes']:.2f} minutes")
print(f"  Total timesteps: {training_info['total_timesteps']:,}")
print(f"  Model saved to: {training_info['model_path']}")

## Cell 5: การประเมินผล Model ที่เทรนแล้วบนข้อมูล Validation และ Test

In [None]:
# Evaluation of the trained model on the validation and test environments
def _run_split_evaluation(trained_model, env, header, done_msg, fail_msg):
    """Run ``DRLAgent.DRL_prediction`` on ``env`` and summarise the outcome.

    The return is measured against ``INITIAL_AMOUNT``. Any exception raised
    while predicting, summarising or printing is reported with ``fail_msg``
    and yields None.
    """
    print(header)
    try:
        account_values, actions = DRLAgent.DRL_prediction(
            model=trained_model,
            environment=env
        )
        start_value = INITIAL_AMOUNT
        end_value = account_values['account_value'].iloc[-1]
        pct_return = (end_value - start_value) / start_value * 100
        summary = dict(
            initial_value=start_value,
            final_value=end_value,
            total_return=pct_return,
            account_values=account_values,
            actions=actions,
        )
        print(done_msg)
        print(f"💰 Initial: ${start_value:,.2f}")
        print(f"💰 Final: ${end_value:,.2f}")
        print(f"📈 Return: {pct_return:.2f}%")
    except Exception as exc:
        print(f"{fail_msg}: {str(exc)}")
        return None
    return summary


def evaluate_trained_model(trained_model, val_env, test_env):
    """Evaluate ``trained_model`` on the validation env, then the test env.

    Returns a dict with keys ``'validation'`` and ``'test'``. Each value is
    either a summary dict (``initial_value``, ``final_value``,
    ``total_return`` in percent, ``account_values``, ``actions``) or None if
    that evaluation failed.
    """
    print("📊 Evaluating trained model...")
    return {
        'validation': _run_split_evaluation(
            trained_model, val_env,
            "\n🔍 Validation evaluation...",
            "✅ Validation completed",
            "❌ Validation evaluation failed",
        ),
        'test': _run_split_evaluation(
            trained_model, test_env,
            "\n🔍 Test evaluation (preview)...",
            "✅ Test evaluation completed",
            "❌ Test evaluation failed",
        ),
    }

def _plot_portfolio_curve(ax, split_result, title, color, missing_text):
    """Plot one split's account-value curve with an INITIAL_AMOUNT reference line.

    ``split_result`` is one entry of ``evaluate_trained_model``'s results (a
    dict holding an ``'account_values'`` DataFrame) or None. For None, only
    ``missing_text`` is drawn in the centre of ``ax``.
    """
    if split_result is None:
        ax.text(0.5, 0.5, missing_text, ha='center', va='center', transform=ax.transAxes)
        return
    ax.plot(split_result['account_values']['account_value'], color=color, linewidth=2)
    ax.axhline(y=INITIAL_AMOUNT, color='red', linestyle='--', alpha=0.7, label='Initial Value')
    ax.set_title(title)
    ax.set_ylabel('Portfolio Value ($)')
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_training_progress(results, model_name):
    """Draw a 2x2 summary figure of the evaluation results, show it and return it.

    Panels:
        * top-left: validation portfolio value;
        * top-right: test portfolio value;
        * bottom-left: bar chart of total returns (%);
        * bottom-right: text summary of the model and each split.

    A split that is missing from ``results`` or set to None is shown as
    "Not Available".

    Args:
        results: Dict as returned by ``evaluate_trained_model`` with optional
            ``'validation'`` and ``'test'`` entries.
        model_name: Name shown in the summary panel.

    Returns:
        matplotlib.figure.Figure: the figure, already displayed with ``plt.show()``.
    """
    print("📊 Creating training progress plots...")
    # .get() so a results dict lacking a split behaves like a failed split
    # instead of raising KeyError.
    splits = [
        ('Validation', results.get('validation'), 'blue'),
        ('Test', results.get('test'), 'green'),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Plots 1-2: portfolio value over each evaluation episode.
    for ax, (label, split_result, color) in zip((axes[0, 0], axes[0, 1]), splits):
        _plot_portfolio_curve(ax, split_result, f'{label} Portfolio Value', color,
                              f'{label} Data\nNot Available')

    # Plot 3: total return of every split that was evaluated successfully.
    returns_ax = axes[1, 0]
    available = [(label, split_result['total_return'], color)
                 for label, split_result, color in splits if split_result is not None]
    if available:
        labels = [label for label, _, _ in available]
        returns_data = [value for _, value, _ in available]
        colors = [color for _, _, color in available]
        bars = returns_ax.bar(labels, returns_data, color=colors, alpha=0.7)
        returns_ax.set_title('Returns Comparison')
        returns_ax.set_ylabel('Return (%)')
        returns_ax.axhline(y=0, color='black', linestyle='-', alpha=0.3)
        for bar, value in zip(bars, returns_data):
            height = bar.get_height()
            above = height > 0
            # Offset in points, not data units, so the label gap stays the same
            # whatever the magnitude of the returns.
            returns_ax.annotate(
                f'{value:.2f}%',
                xy=(bar.get_x() + bar.get_width() / 2., height),
                xytext=(0, 3 if above else -3),
                textcoords='offset points',
                ha='center',
                va='bottom' if above else 'top',
            )
    else:
        returns_ax.text(0.5, 0.5, 'No Return Data\nAvailable', ha='center', va='center',
                        transform=returns_ax.transAxes)

    # Plot 4: text summary of the model and the outcome of each split.
    summary_ax = axes[1, 1]
    summary_ax.axis('off')
    summary_lines = [f"Model: {model_name}"]
    for label, split_result, _ in splits:
        if split_result is None:
            summary_lines.append(f"{label}: not available")
        else:
            summary_lines.append(f"{label} final value: ${split_result['final_value']:,.2f}")
            summary_lines.append(f"{label} return: {split_result['total_return']:.2f}%")
    summary_ax.text(0.1, 0.5, "\n".join(summary_lines), fontsize=14, va='center')

    plt.tight_layout()
    plt.show()
    return fig

# Evaluate on the validation and test environments, then plot the results.
results = evaluate_trained_model(
    trained_model=trained_model,
    val_env=val_env,
    test_env=test_env,
)
fig = plot_training_progress(results=results, model_name=model_name)
print("\n✅ Training evaluation and analysis completed!")