# Airline Pricing Optimization with TD3
### An Interactive Reinforcement Learning Solution
*Clear objective and visual hierarchy.* This notebook trains a TD3 (Twin Delayed Deep Deterministic Policy Gradient) agent to learn airline ticket pricing strategies and presents the workflow in an interactive format.

## 1. Setup and Configuration
*Collapsible configuration section with interactive display.* Configure paths, hyperparameters, and system settings here.


### 1.1 CSS Styling
 *Adds some basic styling for visual elements.*

In [None]:
%%html
 <style>
 .config-box {
     border: 1px solid #e0e0e0;
     border-left: 5px solid #007bff;
     border-radius: 4px;
     padding: 15px;
     margin: 15px 0;
     background-color: #f8f9fa;
     font-family: monospace;
 }
 .config-box h4 {
     margin-top: 0;
     color: #0056b3;
     border-bottom: 1px solid #ccc;
     padding-bottom: 5px;
 }
 .output-box {
     border: 1px solid #ccc;
     border-radius: 4px;
     padding: 10px;
     margin: 10px 0;
     background-color: #fff;
 }
 </style>

### 1.2 Core Parameters Definition
*Define the `Config` dataclass holding all parameters.*

In [None]:
from dataclasses import dataclass, field
from typing import Tuple, Dict, Any, Optional
import os

@dataclass
class Config:
    """Configuration parameters for the Airline Pricing TD3 experiment"""
    # --- Data paths ---
    data_paths: Dict[str, str] = field(default_factory=lambda: {
        'historical_data': "./data/cleaned_dataset.csv",
        'fuel_data': "./data/petroluem_prod_consumption.csv", 
        'aircraft_specs': "./data/airplane_prices_dataset.csv",
        'weather_data_detailed': "./data/weather_data.csv"
    })
    output_dir: str = "./output_td3_interactive" # Separate output dir

    # --- Experiment parameters ---
    seed: int = 42
    n_episodes: int = 200
    batch_size: int = 100
    buffer_size: int = 100000

    # --- TD3 Specific Hyperparameters ---
    actor_lr: float = 1e-4
    critic_lr: float = 1e-3
    gamma: float = 0.99
    tau: float = 0.005
    policy_noise: float = 0.2
    noise_clip: float = 0.5
    policy_freq: int = 2
    exploration_noise: float = 0.1

    # --- Environment parameters ---
    seats_capacity: int = 150
    simulation_length_days: int = 90
    price_range: Tuple[float, float] = (1000.0, 50000.0) # Placeholder

    # --- Agent Network Architecture ---
    actor_hidden_dims: list = field(default_factory=lambda: [256, 128])
    critic_hidden_dims: list = field(default_factory=lambda: [256, 128])

    # --- Training Control ---
    start_timesteps: int = 5000 # Random exploration steps
    validation_freq: int = 20
    patience: int = 15
    use_gpu: bool = True
    max_action_value: float = 1.0 # Corresponds to Tanh output [-1, 1]
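
The TD3-specific fields above are easier to read next to the operations that typically consume them. The following is a minimal NumPy sketch, not the notebook's agent: the function names (`smoothed_target_action`, `soft_update`, `action_to_price`) are hypothetical, and the linear action-to-price mapping is an assumption about how a Tanh output in [-1, 1] could be rescaled to `price_range`.

```python
import numpy as np

# Values copied from the Config defaults above.
TAU = 0.005
POLICY_NOISE = 0.2
NOISE_CLIP = 0.5
MAX_ACTION = 1.0
PRICE_RANGE = (1000.0, 50000.0)


def smoothed_target_action(target_action, rng):
    """Target policy smoothing: add clipped Gaussian noise, then clip to the action bounds."""
    noise = rng.normal(0.0, POLICY_NOISE, size=target_action.shape)
    noise = np.clip(noise, -NOISE_CLIP, NOISE_CLIP)
    return np.clip(target_action + noise, -MAX_ACTION, MAX_ACTION)


def soft_update(target_params, online_params, tau=TAU):
    """Polyak averaging: move the target parameters a fraction tau toward the online ones."""
    return tau * online_params + (1.0 - tau) * target_params


def action_to_price(action, price_range=PRICE_RANGE):
    """Assumed linear map from an action in [-MAX_ACTION, MAX_ACTION] to a price."""
    low, high = price_range
    return low + (action + MAX_ACTION) / (2.0 * MAX_ACTION) * (high - low)


rng = np.random.default_rng(42)
raw_actions = np.array([-0.95, 0.0, 0.9])
smoothed = smoothed_target_action(raw_actions, rng)

# One soft update of an all-zeros target toward all-ones online parameters.
updated_target = soft_update(np.zeros(3), np.ones(3))

# The midpoint action lands halfway through the price range.
mid_price = action_to_price(0.0)
```

With `tau = 0.005` the target parameters move only 0.5% of the way toward the online parameters per update. `policy_freq = 2` would mean the actor and the targets are updated once for every two critic updates.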

### 1.3 Instantiate & Display Configuration
 *Interactive configuration display.* Creates the config object and shows its current values in a formatted box.


In [None]:
from IPython.display import display, HTML
import pprint # For nicely formatting the config display

config = Config() # Create the actual config instance

# Create HTML representation
config_html = (
    f'<div class="config-box">'
    f'<h4>Active Configuration:</h4>'
    f'<pre>{pprint.pformat(config.__dict__, indent=2)}</pre>'
    f'</div>'
)
display(HTML(config_html))
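
The display above reads `config.__dict__`, which reflects every attribute currently set on the instance, including any attached after creation. As a hedged illustration of the alternative, the sketch below compares `vars()` with `dataclasses.asdict()` on a hypothetical `MiniConfig`; the class, its fields, and the paths are stand-ins, not part of the notebook.

```python
from dataclasses import asdict, dataclass, field
from typing import Dict


@dataclass
class MiniConfig:
    """Hypothetical stand-in for Config with two declared fields."""
    seed: int = 42
    data_paths: Dict[str, str] = field(
        default_factory=lambda: {"historical_data": "./data/example.csv"}
    )


cfg = MiniConfig()
# Attribute attached after creation; it is not a declared dataclass field.
cfg.model_path_actor = "./out/actor.pth"

declared_only = asdict(cfg)  # declared fields only, nested containers copied
everything = vars(cfg)       # the instance __dict__, includes the ad hoc attribute
```

`asdict` is the safer choice when the dictionary will be serialized or mutated, since it copies nested containers. `vars` is the one to use when dynamically attached attributes should appear in the display.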

### 1.4 Import Libraries & System Setup
*Import necessary libraries and initialize seeds, device, logging.*

In [None]:
import time
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder, OrdinalEncoder
from collections import deque
import random
import logging
from tqdm.notebook import tqdm
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter as writer
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
from ipywidgets import interact, Dropdown, Button, FloatSlider, IntSlider, Layout, VBox, HBox, FloatLogSlider
import holidays
import math
import copy
import warnings

# Filter simple warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning)

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

# --- Plot Style ---
sns.set(style="whitegrid", palette="muted")
plt.rcParams['figure.figsize'] = (10, 5) # Slightly smaller default plots

# --- Random Seeds ---
def set_seeds(seed_value: int):
    """Seed the Python, NumPy, and PyTorch random number generators."""
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)

set_seeds(config.seed)
logging.info(f"Seeds set to {config.seed}")

# --- Device Configuration ---
device = torch.device("cuda" if config.use_gpu and torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# --- Output Directory ---
os.makedirs(config.output_dir, exist_ok=True)
config.model_path_actor = os.path.join(config.output_dir, "td3_actor_best.pth")
config.model_path_critic = os.path.join(config.output_dir, "td3_critic_best.pth")
logging.info(f"Output Dir: {config.output_dir}")
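
A quick way to check that seeding behaves as intended is to reseed and confirm that the same draws come back. The sketch below does this for the Python and NumPy generators only; PyTorch is left out to keep the example light, and the helper `draw` is hypothetical.

```python
import random

import numpy as np


def set_seeds(seed_value: int) -> None:
    """Seed the Python and NumPy global generators (PyTorch omitted in this sketch)."""
    random.seed(seed_value)
    np.random.seed(seed_value)


def draw():
    """Take one sample from each global generator."""
    return random.random(), float(np.random.rand())


set_seeds(42)
first = draw()
set_seeds(42)
second = draw()  # identical to `first` because both generators were reset
```

Seeding alone may not make GPU training bit-for-bit reproducible; PyTorch offers further switches such as `torch.use_deterministic_algorithms(True)` for that purpose, usually at some cost in speed.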

## 2. Data Pipeline
---
*Interactive data exploration.* Load, process, and interactively explore the flight data.
 

### 2.1 DataProcessor Class
 *Handles data loading, cleaning, merging, feature engineering, and scaling.*
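
The class separates a `fit` phase, which learns statistics from training data, from a `transform` phase, which reuses them. A stripped-down sketch of that pattern for median imputation is shown here; `MedianImputer` and the toy frames are hypothetical and only mirror the idea, not the full class.

```python
import pandas as pd


class MedianImputer:
    """Hypothetical, minimal analogue of the fit/transform split."""

    def __init__(self):
        self.medians = {}
        self.is_fitted = False

    def fit(self, df, cols):
        # Learn one median per column from the training data only.
        self.medians = {c: df[c].median() for c in cols}
        self.is_fitted = True
        return self

    def transform(self, df):
        if not self.is_fitted:
            raise RuntimeError("Call fit() before transform().")
        # Fill gaps with the stored training medians; returns a new DataFrame.
        return df.fillna(self.medians)


train = pd.DataFrame({"Days_left": [1.0, 3.0, None, 5.0]})
valid = pd.DataFrame({"Days_left": [None, 40.0]})

imputer = MedianImputer().fit(train, ["Days_left"])
valid_filled = imputer.transform(valid)
```

The validation gap is filled with the training median (3.0), not with a statistic computed from the validation rows, which keeps information from held-out data out of the preprocessing step.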

In [None]:
class DataProcessor:
    """
    Handles the loading, preprocessing, feature engineering, and scaling
    of various data sources for the airline pricing environment.

    Key Steps:
    1. Loads raw data (historical flights, aircraft specs, weather).
    2. Fetches holidays dynamically.
    3. Merges external data onto the historical flight data.
    4. Engineers time-based and statistical features.
    5. Handles missing values using defined strategies.
    6. Fits encoders and scalers during the `fit` phase using training data.
    7. Applies learned transformations during the `transform` phase.
    8. Provides utilities for price scaling.
    """
    def __init__(self, config: Config):
        """
        Initializes the DataProcessor with configuration and sets up scalers/encoders.

        Args:
            config: A dataclass or object containing configuration parameters,
                    including data paths and processing settings.
        """
        self.config = config
        self.paths = config.data_paths # Dictionary of data file paths

        # --- Scalers ---
        # Initialize scalers for numerical features. They will be fitted later.
        # Price (Fare) Scaler: Maps fares to [0, 1] range, matching Actor's Tanh output after shifting.
        self.price_scaler = MinMaxScaler(feature_range=(0, 1))
        # Days Left Scaler: Maps days left to [0, 1] range.
        self.days_left_scaler = MinMaxScaler()
        # Standard Scalers: Standardize features to have mean 0 and std dev 1.
        self.duration_scaler = StandardScaler()       # For 'Duration_in_hours'
        self.temp_scaler = StandardScaler()           # For 'Weather_Temp_C'
        self.humidity_scaler = StandardScaler()       # For 'Weather_Humidity'
        self.age_scaler = StandardScaler()            # For 'Aircraft_Age'
        self.fuel_eff_scaler = StandardScaler()       # For 'Fuel Consumption (L/hour)'
        self.fare_ma7_scaler = StandardScaler()       # For 'Price_MA7' (engineered feature)
        self.fare_lag1_scaler = StandardScaler()      # For 'Price_Lag1' (engineered feature)

        # --- Encoders ---
        # Initialize encoders for categorical features.
        # Label Encoders: Assign a unique integer to each category.
        self.route_encoder = LabelEncoder()           # For 'Route' (e.g., 'Delhi-Mumbai')
        self.airline_encoder = LabelEncoder()         # For 'Airline' name
        self.class_encoder = LabelEncoder()           # For 'Class' (e.g., 'Economy')
        self.source_encoder = LabelEncoder()          # For 'Source' city
        self.destination_encoder = LabelEncoder()     # For 'Destination' city
        self.aircraft_model_encoder = LabelEncoder()  # For 'AircraftType' / 'Model'
        self.weather_cond_encoder = LabelEncoder()    # For 'Weather_Condition' text
        self.season_encoder = LabelEncoder()          # For generated 'Season'

        # Ordinal Encoders: Assign integers based on a predefined order.
        # Handle unknown values by mapping them to -1.
        self.stops_encoder = OrdinalEncoder(
            categories=[['non-stop', '1-stop', '2-stops', '3-stops', '4-stops']], # Explicit order
            handle_unknown='use_encoded_value', unknown_value=-1
        )
        self.journey_day_encoder = OrdinalEncoder(
            categories=[['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']],
            handle_unknown='use_encoded_value', unknown_value=-1
        )
        self.dep_time_encoder = OrdinalEncoder(
            categories=[['Before 6 AM', '6 AM - 12 PM', '12 PM - 6 PM', 'After 6 PM']],
            handle_unknown='use_encoded_value', unknown_value=-1
        )
        self.arr_time_encoder = OrdinalEncoder(
             categories=[['Before 6 AM', '6 AM - 12 PM', '12 PM - 6 PM', 'After 6 PM']],
             handle_unknown='use_encoded_value', unknown_value=-1
        )

        # --- Fitted Parameters & State ---
        # These store values learned during the 'fit' phase or other necessary info.
        self.fuel_price_mean = 2.5  # Fallback value as real price data is unavailable
        self.historical_fare_stats: Dict[Tuple[str, str], Dict[str, float]] = {} # Stores {(Route, Airline): {'mean': X, 'median': Y}}
        self.global_fare_mean: float = 0.0 # Fallback historical mean fare across all data
        self.is_fitted: bool = False # Flag to track if `fit` has been called

        # --- Data for Environment ---
        # Store mappings needed by the simulation environment (e.g., for cost calculation).
        self.aircraft_fuel_rates: Dict[str, float] = {} # Maps Aircraft Model -> Fuel Consumption Rate
        self.maintenance_costs: Dict[str, float] = {} # Maps Aircraft Model -> Hourly Maintenance Cost

    def _load_data(self) -> Dict[str, pd.DataFrame]:
        """
        Loads dataframes specified in the config's data_paths.
        Handles file not found errors for optional files.

        Returns:
            A dictionary where keys are data names (e.g., 'historical_data')
            and values are the loaded pandas DataFrames.
        """
        data = {}
        logging.info("Loading data files specified in configuration...")
        for name, path in self.paths.items():
            # Skip holiday data path if present, as it's fetched dynamically
            if name == 'holiday_data':
                continue
            try:
                # Check if the file exists
                if os.path.exists(path):
                    data[name] = pd.read_csv(path)
                    logging.info(f"  Successfully loaded '{name}' from '{path}' (Shape: {data[name].shape})")
                else:
                    # Historical data is essential, raise error if missing
                    if name == 'historical_data':
                         logging.error(f"CRITICAL ERROR: Required historical data file not found at {path}.")
                         raise FileNotFoundError(f"Required file missing: {path}")
                    # Other files might be optional
                    else:
                         logging.warning(f"Optional file not found: {path}. Skipping '{name}'.")
                         data[name] = pd.DataFrame() # Assign empty DataFrame as placeholder
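            except FileNotFoundError:
                # Re-raise the missing-required-file error instead of swallowing it below
                raise
            except Exception as e: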
                # Catch other potential errors during loading (e.g., parsing errors)
                logging.error(f"Error loading '{name}' from {path}: {e}")
                data[name] = pd.DataFrame() # Assign empty DataFrame on error
        return data

    def _fetch_holidays(self, years: list[int]) -> pd.DataFrame:
        """
        Fetches Indian national holidays for the specified year range using the `holidays` library.

        Args:
            years: A list of years for which to fetch holidays.

        Returns:
            A pandas DataFrame with 'Date' and 'HolidayName' columns,
            or an empty DataFrame if fetching fails or no holidays are found.
        """
        logging.info(f"Fetching Indian national holidays for years: {years}...")
        try:
            # Use the holidays library for India ('IN')
            in_holidays = holidays.country_holidays('IN', years=years)
            # Check if the result is empty
            if not in_holidays:
                logging.warning(f"No holidays found for years {years}.")
                return pd.DataFrame({'Date': pd.to_datetime([]), 'HolidayName': []}) # Ensure correct dtypes

            # Convert the dictionary of holidays to a list of dicts
            holiday_list = [{'Date': date, 'HolidayName': name}
                            for date, name in sorted(in_holidays.items())]
            # Create DataFrame and ensure Date column is datetime
            holiday_df = pd.DataFrame(holiday_list)
            holiday_df['Date'] = pd.to_datetime(holiday_df['Date'])
            logging.info(f"  Successfully fetched {len(holiday_df)} holiday entries.")
            return holiday_df
        except Exception as e:
            # Handle potential errors during the fetch process
            logging.error(f"Error fetching holidays using 'holidays' library: {e}")
            return pd.DataFrame({'Date': pd.to_datetime([]), 'HolidayName': []}) # Return empty DF

    def fit(self, data: Dict[str, pd.DataFrame]):
        """
        Fits the encoders, scalers, and calculates imputation values using the provided training data.
        This method should only be called once on the training dataset.

        Args:
            data: A dictionary containing the raw dataframes loaded by `_load_data`.
        """
        logging.info("Starting DataProcessor fitting phase (learning transformations)...")

        # --- Get Primary Data ---
        hist_data = data.get('historical_data') # Expects 'cleaned_dataset.csv'
        aircraft_specs = data.get('aircraft_specs', pd.DataFrame())
        weather_data = data.get('weather_data_detailed', pd.DataFrame())

        if hist_data is None or hist_data.empty:
            raise ValueError("Cannot fit: Historical data ('cleaned_dataset.csv') is missing or empty.")

        # --- Initial Cleaning & Preparation ---
        logging.info("Performing initial cleaning and preparation...")
        # Rename 'Fare' to 'Price' for internal consistency
        hist_data.rename(columns={'Fare': 'Price'}, inplace=True, errors='ignore')
        if 'Price' not in hist_data.columns: raise ValueError("Missing 'Price' (originally 'Fare') column.")
        # Convert 'Date' column to datetime objects, coercing errors to NaT
        hist_data['Date'] = pd.to_datetime(hist_data['Date'], errors='coerce')
        # Drop rows where essential columns (Date, Source, Dest, Price) are missing
        initial_rows = len(hist_data)
        hist_data.dropna(subset=['Date', 'Source', 'Destination', 'Price'], inplace=True)
        if initial_rows > len(hist_data): logging.info(f"  Dropped {initial_rows - len(hist_data)} rows with missing essential values.")
        # Create the 'Route' feature by combining Source and Destination
        hist_data['Route'] = hist_data['Source'] + '-' + hist_data['Destination']

        # Prepare weather data dates
        if not weather_data.empty and 'datetime' in weather_data.columns:
            weather_data['datetime'] = pd.to_datetime(weather_data['datetime'], errors='coerce')
            weather_data.dropna(subset=['datetime'], inplace=True)
        else:
             weather_data = pd.DataFrame() # Ensure empty DF if unusable

        # --- Fetch Holidays ---
        # Determine year range from historical data to fetch relevant holidays
        min_year = hist_data['Date'].dt.year.min()
        max_year = hist_data['Date'].dt.year.max()
        holiday_df = self._fetch_holidays(years=list(range(min_year, max_year + 1)))

        # --- Prepare Data for Fitting Transformations ---
        # Work on a copy from here on (note: the cleaning steps above already modified `hist_data` in place)
        fit_data = hist_data.copy()
        # Merge external sources (holidays, aircraft, weather)
        fit_data = self._merge_external_data(fit_data, holiday_df, aircraft_specs, weather_data)
        # Engineer features (like Season, Price MA/Lag) that might be needed for scaling/encoding
        fit_data = self._engineer_features(fit_data)
        # Handle missing values *before* fitting encoders/scalers
        fit_data = self._handle_missing_values(fit_data, is_fitting=True) # is_fitting=True calculates imputation values

        # --- Calculate Historical Fare Statistics ---
        # Compute mean/median fare per route/airline for later use (e.g., demand heuristic)
        logging.info("Calculating historical fare statistics per route/airline...")
        try:
            self.historical_fare_stats = fit_data.groupby(['Route', 'Airline'])['Price'].agg(['mean', 'median']).to_dict('index')
            self.global_fare_mean = fit_data['Price'].mean() # Calculate global fallback mean
            logging.info(f"  Calculated stats for {len(self.historical_fare_stats)} route/airline pairs. Global mean fare: {self.global_fare_mean:.2f}")
        except KeyError:
             logging.error("  Could not group by 'Route' or 'Airline' - columns might be missing after merge/impute.")
             self.global_fare_mean = fit_data['Price'].mean() if 'Price' in fit_data.columns else 0.0


        # --- Fit Encoders ---
        # Learn the mapping from categories to integers for all categorical features
        self._fit_categorical_encoders(fit_data) # Pass the prepared data

        # --- Fit Scalers ---
        # Learn the scaling parameters (min/max or mean/std) for all numerical features
        self._fit_numerical_scalers(fit_data) # Pass the prepared data

        # Set the flag indicating fitting is complete
        self.is_fitted = True
        logging.info("DataProcessor fitting phase completed successfully.")
        return self # Return self for potential chaining

    def transform(self, data: Dict[str, pd.DataFrame], is_validation=False) -> pd.DataFrame:
        """
        Applies the learned transformations (imputation, encoding, scaling)
        to the provided data.

        Args:
            data: A dictionary containing the raw dataframes (similar structure to input for `fit`).
            is_validation: Flag indicating if this is for validation/test data (currently only echoed in the log message).

        Returns:
            A pandas DataFrame with all preprocessing steps applied.
        """
        if not self.is_fitted:
            raise RuntimeError("DataProcessor must be fitted using .fit() before calling .transform().")
        logging.info(f"Starting DataProcessor transformation phase (Validation Mode: {is_validation})...")

        # --- Get Primary Data ---
        hist_data = data.get('historical_data')
        aircraft_specs = data.get('aircraft_specs', pd.DataFrame())
        weather_data = data.get('weather_data_detailed', pd.DataFrame())

        if hist_data is None or hist_data.empty:
             logging.error("Cannot transform: Historical data is missing or empty.")
             return pd.DataFrame()

        # --- Initial Cleaning & Preparation ---
        # Perform the same initial steps as in `fit` to ensure consistency
        df = hist_data.copy()
        df.rename(columns={'Fare': 'Price'}, inplace=True, errors='ignore')
        if 'Price' not in df.columns: raise ValueError("Missing 'Price' (originally 'Fare') column.")
        df['Date'] = pd.to_datetime(df['Date'], errors='coerce')
        df.dropna(subset=['Date', 'Source', 'Destination', 'Price'], inplace=True)
        df['Route'] = df['Source'] + '-' + df['Destination']

        # Prepare weather data dates
        if not weather_data.empty and 'datetime' in weather_data.columns:
            weather_data['datetime'] = pd.to_datetime(weather_data['datetime'], errors='coerce')
            weather_data.dropna(subset=['datetime'], inplace=True)
        else: weather_data = pd.DataFrame()

        # --- Fetch Holidays (for the date range in *this* data chunk) ---
        if not df.empty:
             min_year = df['Date'].dt.year.min()
             max_year = df['Date'].dt.year.max()
             holiday_df = self._fetch_holidays(years=list(range(min_year, max_year + 1)))
        else: holiday_df = pd.DataFrame() # Handle empty input df


        # --- Apply Transformations in Order ---
        # 1. Merge external data
        df = self._merge_external_data(df, holiday_df, aircraft_specs, weather_data)
        # 2. Engineer features
        df = self._engineer_features(df)
        # 3. Handle missing values (using stored imputation values from `fit`)
        df = self._handle_missing_values(df, is_fitting=False)
        # 4. Apply learned encoders
        df = self._apply_encoders(df)
        # 5. Apply learned scalers
        df = self._apply_scalers(df)

        logging.info("Data transformation phase completed.")

        # Final check for any remaining NaNs after all steps
        if df.isnull().any().any():
            logging.warning("NaNs detected in final transformed data. Review imputation steps.")
            # Consider dropping or force-filling remaining NaNs based on strategy
            # df.fillna(0, inplace=True) # Example: Force fill

        return df.reset_index(drop=True) # Return with clean index

    # --------------------------------------------------------------------------
    # Internal Helper Methods (_merge, _engineer, _handle_missing, _fit*, _apply*)
    # --------------------------------------------------------------------------

    def _merge_external_data(self, base_df: pd.DataFrame, holiday_df: pd.DataFrame,
                             aircraft_df: pd.DataFrame, weather_df: pd.DataFrame) -> pd.DataFrame:
        """Merges holiday, aircraft specs, and weather data onto the base flight data."""
        logging.debug("Merging external data sources...")
        df = base_df.reset_index(drop=True) # Copy with a fresh RangeIndex so merged columns align row-by-row
        # Use normalized date (date part only) for merging day-level data
        df['Date_Norm'] = df['Date'].dt.normalize()

        # --- Merge Holidays ---
        if not holiday_df.empty:
            holiday_df['Date_Norm'] = holiday_df['Date'].dt.normalize()
            # Merge based on date only for national holidays
            df = pd.merge(df, holiday_df[['Date_Norm', 'HolidayName']], on='Date_Norm', how='left')
            df['IsHoliday'] = df['HolidayName'].notna().astype(int)
            df.drop('HolidayName', axis=1, inplace=True, errors='ignore')
            logging.debug(f"  Merged holiday data. Found {df['IsHoliday'].sum()} holiday flags.")
        else:
            df['IsHoliday'] = 0 # Default if no holiday data

        # --- Merge Aircraft Specs ---
        # Assumption: 'AircraftType' in base_df maps to 'Model' in aircraft_df
        key_hist = 'AircraftType'
        key_spec = 'Model'
        # Add placeholder columns first to ensure they exist even if merge fails
        df['Age'] = np.nan
        df['Fuel Consumption (L/hour)'] = np.nan
        df['Hourly Maintenance Cost ($)'] = np.nan

        if not aircraft_df.empty and key_spec in aircraft_df.columns:
            # Pre-process specs (only needs to happen once conceptually, but safe to repeat)
            if 'Production Year' in aircraft_df.columns:
                 current_year = datetime.now().year
                 # Ensure Production Year is numeric before subtracting
                 prod_year_num = pd.to_numeric(aircraft_df['Production Year'], errors='coerce')
                 aircraft_df['Age'] = current_year - prod_year_num
                 # Handle potential future dates or errors, ensure non-negative age
                 aircraft_df['Age'] = aircraft_df['Age'].apply(lambda x: max(0, x) if pd.notnull(x) else np.nan)
            # Store mappings if this is the fitting phase (based on is_fitted flag)
            # (Mappings are now stored in self.aircraft_fuel_rates etc directly in __init__ or fit)

            # Perform merge only if the key exists in the historical data
            if key_hist in df.columns:
                 spec_cols_to_merge = [key_spec, 'Age', 'Fuel Consumption (L/hour)', 'Hourly Maintenance Cost ($)']
                 # Select only columns that actually exist in aircraft_df
                 spec_cols_exist = [col for col in spec_cols_to_merge if col in aircraft_df.columns]

                 # Perform the merge safely, handling potential duplicate columns if keys are same
                 if key_hist == key_spec:
                     # Merge and drop duplicate key column from right DF if names are identical
                     df_merged = pd.merge(df, aircraft_df[spec_cols_exist], on=key_hist, how='left', suffixes=('', '_spec'))
                 else:
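                      # Keep the first spec row per model so the left merge cannot duplicate flight rows
                      specs_unique = aircraft_df[spec_cols_exist].drop_duplicates(subset=key_spec)
                      df_merged = pd.merge(df, specs_unique, left_on=key_hist, right_on=key_spec, how='left', suffixes=('', '_spec'))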
                      # Drop the key from the right dataframe if names differ
                      if key_spec in df_merged.columns: df_merged.drop(key_spec, axis=1, inplace=True, errors='ignore')

                 # Overwrite original columns with merged data where available
                 for col in ['Age', 'Fuel Consumption (L/hour)', 'Hourly Maintenance Cost ($)']:
                      spec_col = col + '_spec'
                      if spec_col in df_merged.columns:
                          # Update original column only where spec data is not NaN
                          df[col] = df_merged[spec_col].combine_first(df[col])
                          # Drop the temporary spec column
                          # df.drop(spec_col, axis=1, inplace=True) # Removed, handled later if needed

                 logging.debug(f"  Merged aircraft specs using '{key_hist}' -> '{key_spec}'.")
            else:
                 logging.warning(f"  Cannot merge aircraft specs: Key '{key_hist}' not found in base data.")
        else:
            logging.warning(f"  Skipping aircraft specs merge: Specs data missing, empty, or lacks '{key_spec}' column.")


        # --- Merge Weather Data ---
        origin_col = 'Source' # City column in flight data to match weather data
        # Add placeholder columns first
        df['Weather_Temp_C'] = np.nan
        df['Weather_Humidity'] = np.nan
        df['Weather_Condition'] = 'Unknown'

        if not weather_df.empty and all(c in weather_df.columns for c in ['city', 'datetime', 'temp_c', 'humidity', 'condition']) and origin_col in df.columns:
            # Aggregate weather to daily level (mean temp/humidity, mode condition)
            weather_daily = weather_df.groupby(
                ['city', pd.Grouper(key='datetime', freq='D')]
            ).agg(
                Weather_Temp_C=('temp_c', 'mean'),
                Weather_Humidity=('humidity', 'mean'),
                Weather_Condition=('condition', lambda x: x.mode()[0] if not x.mode().empty else 'Unknown') # Handle empty modes
            ).reset_index()
            # Rename columns to match merge keys
            weather_daily = weather_daily.rename(columns={'city': origin_col, 'datetime': 'Date_Norm'})
            weather_daily['Date_Norm'] = weather_daily['Date_Norm'].dt.normalize() # Ensure date part only

            # Perform the merge
            df = pd.merge(df, weather_daily, on=['Date_Norm', origin_col], how='left', suffixes=('', '_weather'))

            # Overwrite placeholder columns with merged weather data
            if 'Weather_Temp_C_weather' in df.columns:
                df['Weather_Temp_C'] = df['Weather_Temp_C_weather'].combine_first(df['Weather_Temp_C'])
                df.drop('Weather_Temp_C_weather', axis=1, inplace=True)
            if 'Weather_Humidity_weather' in df.columns:
                df['Weather_Humidity'] = df['Weather_Humidity_weather'].combine_first(df['Weather_Humidity'])
                df.drop('Weather_Humidity_weather', axis=1, inplace=True)
            if 'Weather_Condition_weather' in df.columns:
                # Fill NaNs in original with merged, then fill remaining NaNs with 'Unknown'
                df['Weather_Condition'] = df['Weather_Condition_weather'].combine_first(df['Weather_Condition']).fillna('Unknown')
                df.drop('Weather_Condition_weather', axis=1, inplace=True)

            logging.debug(f"  Merged daily aggregated weather data on '{origin_col}'.")
        else:
            if not weather_df.empty: logging.warning(f"  Skipping weather merge: Required columns missing or base df lacks '{origin_col}'.")


        # --- Add Fallback Fuel Price ---
        # This adds a constant value as real fuel price data was not usable/available
        df['FuelPrice'] = self.fuel_price_mean

        # Clean up helper columns
        df.drop(['Date_Norm'], axis=1, inplace=True, errors='ignore')
        # Drop any other temporary merge columns (like '_spec') if they still exist
        df.drop([col for col in df.columns if '_spec' in col], axis=1, inplace=True, errors='ignore')


        return df

    def _get_season(self, month: int) -> str:
        """Maps month number to meteorological season (Northern Hemisphere)."""
        if month in [12, 1, 2]: return 'Winter'
        if month in [3, 4, 5]: return 'Spring'
        if month in [6, 7, 8]: return 'Summer'
        if month in [9, 10, 11]: return 'Fall'
        return 'Unknown' # Handle potential NaNs or invalid months

    def _engineer_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Creates time-based features and calculates rolling/lag features."""
        logging.debug("Engineering time-based features (DayOfWeek, Season, etc.)...")
        # --- Time Features ---
        df['DayOfWeek'] = df['Date'].dt.dayofweek # Monday=0, Sunday=6
        df['Month'] = df['Date'].dt.month
        df['WeekOfYear'] = df['Date'].dt.isocalendar().week.astype(int) # ISO week number
        df['DayOfYear'] = df['Date'].dt.dayofyear
        df['IsWeekend'] = df['DayOfWeek'].isin([5, 6]).astype(int) # Flag for Saturday/Sunday
        df['Season'] = df['Month'].apply(self._get_season) # Map month to season

        # --- Rolling Average / Lag Features ---
        # Requires data to be sorted correctly for meaningful results
        logging.debug("Calculating rolling average (MA7) and lag (Lag1) for Price...")
        df = df.sort_values(['Route', 'Airline', 'Date'], ascending=True) # Ensure correct order within groups

        if 'Price' in df.columns:
            # Rolling mean over the last 7 observations (rows, not calendar days) per Route-Airline group
            df['Price_MA7'] = df.groupby(['Route', 'Airline'])['Price'].transform(
                lambda x: x.rolling(window=7, min_periods=1).mean() # Need at least 1 period
            )
            # Price of the previous observation (row) in the same Route-Airline group
            df['Price_Lag1'] = df.groupby(['Route', 'Airline'])['Price'].transform(
                lambda x: x.shift(1) # Shift by one row; the first row of each group becomes NaN
            )
        else:
            # Add NaN columns if Price is missing, though Price is essential
            logging.warning("  'Price' column not found for MA/Lag calculation.")
            df['Price_MA7'] = np.nan
            df['Price_Lag1'] = np.nan

        return df

    def _handle_missing_values(self, df: pd.DataFrame, is_fitting=False) -> pd.DataFrame:
        """Imputes missing values using stored medians (for transform) or calculates them (for fit)."""
        logging.debug(f"Handling missing values (is_fitting={is_fitting})...")

        # Columns to impute and their strategies (numeric: median, categorical: 'Unknown')
        numeric_cols = [
            'Duration_in_hours', 'Days_left', 'Price', 'Price_MA7', 'Price_Lag1',
            'Weather_Temp_C', 'Weather_Humidity', 'Age',
            'Fuel Consumption (L/hour)', 'Hourly Maintenance Cost ($)'
        ]
        categorical_cols = [
            'Airline', 'Flight_code', 'Class', 'Source', 'Departure', 'Total_stops',
            'Arrival', 'Destination', 'Journey_day', 'Weather_Condition', 'Season',
            'AircraftType' # Impute this if missing, might be needed for merge/encoding
        ]

        # Impute numeric columns with median
        for col in numeric_cols:
            if col in df.columns:
                median_attribute_name = f"{col}_median" # Attribute name to store/retrieve median
                if is_fitting:
                    # Calculate median from current data if fitting
                    median_val = df[col].median()
                    # Store the calculated median on the processor instance
                    setattr(self, median_attribute_name, median_val)
                    if pd.isna(median_val): logging.warning(f"  Median for '{col}' is NaN during fit.")
                else:
                    # Retrieve the stored median if transforming, default to 0 if not found
                    median_val = getattr(self, median_attribute_name, 0)

                # Perform imputation using the determined median value
                # Ensure median_val itself is not NaN before filling
                fill_value = median_val if pd.notna(median_val) else 0
                df[col] = df[col].fillna(fill_value)
            else:
                logging.debug(f"  Column '{col}' not found for numeric imputation.")

        # Impute categorical columns with 'Unknown'
        for col in categorical_cols:
             if col in df.columns:
                 df[col] = df[col].fillna('Unknown') # Use a consistent placeholder string
             else:
                 logging.debug(f"  Column '{col}' not found for categorical imputation.")

        # Final check for any remaining NaNs
        remaining_nans = df.isnull().sum()
        remaining_nans = remaining_nans[remaining_nans > 0]
        if not remaining_nans.empty:
            logging.warning(f"  NaNs still remain after imputation in columns: {remaining_nans.index.tolist()}")
            # Consider logging the count: {remaining_nans.to_dict()}

        return df

    def _fit_categorical_encoders(self, df: pd.DataFrame):
        """Fits LabelEncoder and OrdinalEncoder instances on the data."""
        logging.debug("Fitting categorical encoders...")

        # Dictionary mapping feature names to their corresponding LabelEncoder instances
        label_encoders_map = {
            'Route': self.route_encoder, 'Airline': self.airline_encoder, 'Class': self.class_encoder,
            'Source': self.source_encoder, 'Destination': self.destination_encoder, 'Season': self.season_encoder,
            'Weather_Condition': self.weather_cond_encoder, 'AircraftType': self.aircraft_model_encoder
        }

        # Fit each LabelEncoder
        for name, encoder in label_encoders_map.items():
            series = df.get(name)
            if series is not None and not series.empty:
                 try:
                     # Include 'Unknown' explicitly in case it only appears in test/validation
                     unique_values = np.append(series.astype(str).unique(), 'Unknown')
                     encoder.fit(unique_values)
                     logging.debug(f"  Fitted LabelEncoder for '{name}' with {len(encoder.classes_)} classes.")
                 except Exception as e:
                     logging.error(f"  Error fitting LabelEncoder for '{name}': {e}")
            else:
                 logging.warning(f"  Cannot fit LabelEncoder for '{name}', data series is missing or empty.")

        # --- Explicitly Fit Ordinal Encoders ---
        # Although categories are predefined, calling fit ensures consistency and attribute access.
        ordinal_encoders_map = {
            'Total_stops': self.stops_encoder,
            'Journey_day': self.journey_day_encoder,
            'Departure': self.dep_time_encoder,
            'Arrival': self.arr_time_encoder
        }

        for name, encoder in ordinal_encoders_map.items():
            series = df.get(name)
            if series is not None and not series.empty:
                try:
                    # Fit requires a 2D array, so reshape the series
                    encoder.fit(series.astype(str).values.reshape(-1, 1))
                    # Now access categories_ safely after fitting
                    logging.debug(f"  Fitted OrdinalEncoder for '{name}'. Categories used: {encoder.categories_}")
                except Exception as e:
                    logging.error(f"  Error fitting OrdinalEncoder for '{name}': {e}")
            else:
                logging.warning(f"  Cannot fit OrdinalEncoder for '{name}', data series is missing or empty.")
                                
        # Log the final categories for confirmation, skipping any encoder whose fit failed above
        # (accessing categories_ on an unfitted encoder would raise AttributeError).
        for name, encoder in ordinal_encoders_map.items():
            if hasattr(encoder, 'categories_'):
                logging.debug(f"  Using predefined OrdinalEncoder for '{name}': {encoder.categories_}")

    def _apply_encoders(self, df: pd.DataFrame) -> pd.DataFrame:
        """Applies the fitted encoders to transform categorical columns."""
        logging.debug("Applying categorical encoders...")

        # Map feature names to their fitted encoders and the desired output column name
        encoders_to_apply = {
            'Route': (self.route_encoder, 'Route_encoded'), 'Airline': (self.airline_encoder, 'Airline_encoded'),
            'Class': (self.class_encoder, 'Class_encoded'), 'Source': (self.source_encoder, 'Source_encoded'),
            'Destination': (self.destination_encoder, 'Destination_encoded'), 'Season': (self.season_encoder, 'Season_encoded'),
            'Weather_Condition': (self.weather_cond_encoder, 'WeatherCond_encoded'), # Shortened name
            'AircraftType': (self.aircraft_model_encoder, 'AircraftModel_encoded'), # Using AircraftType as source
            'Total_stops': (self.stops_encoder, 'Stops_encoded'), # Shortened name
            'Journey_day': (self.journey_day_encoder, 'JourneyDay_encoded'), # Shortened name
            'Departure': (self.dep_time_encoder, 'Departure_encoded'),
            'Arrival': (self.arr_time_encoder, 'Arrival_encoded')
        }

        # Apply each encoder
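        for source_col, (encoder, target_col) in encoders_to_apply.items():
            series = df.get(source_col)
            if series is None:
                # Source column absent: the helpers would return an empty array, which cannot
                # be assigned to a non-empty frame, so fill with the unknown-category sentinel
                logging.warning(f"Column '{source_col}' not found. Filling '{target_col}' with -1.")
                df[target_col] = -1
                continue
            # Use helper functions for safe transformation
            if isinstance(encoder, LabelEncoder):
                df[target_col] = self._safe_transform_le(encoder, series, source_col)
            elif isinstance(encoder, OrdinalEncoder):
                df[target_col] = self._safe_transform_ord(encoder, series, source_col)
            else:
                 logging.warning(f"Unknown encoder type for column '{source_col}'. Skipping.")
                 df[target_col] = -1 # Assign default error value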

        return df

    def _safe_transform_le(self, encoder: LabelEncoder, series: Optional[pd.Series], name: str) -> np.ndarray:
        """Safely transforms a Series using a fitted LabelEncoder, handling unknowns."""
        default_value = -1 # Value for unknown categories
        # A missing column has no length to mirror, so return an empty array
        if series is None: return np.full(0, default_value, dtype=int)
        if not hasattr(encoder, 'classes_'):
            logging.warning(f"LabelEncoder '{name}' not fitted. Returning {default_value}.")
            return np.full(len(series), default_value, dtype=int)

        series_str = series.astype(str)
        # Create a mask for values known by the encoder
        known_mask = series_str.isin(encoder.classes_)
        # Initialize output array with the default value
        encoded_values = np.full(series_str.shape, default_value, dtype=int)

        try:
            # Transform only the known values
            if known_mask.any():
                encoded_values[known_mask] = encoder.transform(series_str[known_mask])
            # Log if unknowns were encountered
            if (~known_mask).any():
                 logging.debug(f"  Unknown categories encountered in '{name}': {series_str[~known_mask].unique()}")
            return encoded_values
        except Exception as e:
            logging.error(f"  Error applying LabelEncoder for '{name}': {e}")
            return np.full(series_str.shape, default_value, dtype=int) # Return default on error

    def _safe_transform_ord(self, encoder: OrdinalEncoder, series: Optional[pd.Series], name: str) -> np.ndarray:
        """Safely transforms a Series using a fitted OrdinalEncoder."""
        default_value = -1 # Value for unknown categories (should match encoder's setting)
        if series is None: return np.full(0, default_value, dtype=int) # Missing column: nothing to encode
        if not hasattr(encoder, 'categories_'):
            logging.warning(f"OrdinalEncoder '{name}' not fitted. Returning {default_value}.")
            return np.full(len(series), default_value, dtype=int)

        try:
            # OrdinalEncoder with handle_unknown='use_encoded_value' handles unknowns internally
            encoded_values = encoder.transform(series.astype(str).values.reshape(-1, 1)).flatten()
            return encoded_values.astype(int)
        except Exception as e:
            # This might catch errors if a value is truly unexpected and not handled
            logging.error(f"  Error applying OrdinalEncoder for '{name}': {e}")
            return np.full(len(series), default_value, dtype=int)

    def _fit_numerical_scalers(self, df: pd.DataFrame):
        """Fits StandardScaler and MinMaxScaler instances on numerical columns."""
        logging.debug("Fitting numerical scalers...")

        # Map features to their corresponding scaler instances
        scalers_map = {
            'Price': self.price_scaler, 'Days_left': self.days_left_scaler,
            'Duration_in_hours': self.duration_scaler, 'Weather_Temp_C': self.temp_scaler,
            'Weather_Humidity': self.humidity_scaler, 'Age': self.age_scaler,
            'Fuel Consumption (L/hour)': self.fuel_eff_scaler,
            'Price_MA7': self.fare_ma7_scaler, 'Price_Lag1': self.fare_lag1_scaler
        }

        # Fit each scaler
        for name, scaler in scalers_map.items():
            series = df.get(name)
            # Check if series exists, is numeric, and has non-NaN values
            if series is not None and pd.api.types.is_numeric_dtype(series):
                series_cleaned = series.dropna()
                if not series_cleaned.empty:
                    try:
                        # Fit scaler on the non-NaN values, reshaped for sklearn input
                        scaler.fit(series_cleaned.values.reshape(-1, 1))
                        logging.debug(f"  Fitted {scaler.__class__.__name__} for '{name}'.")
                    except Exception as e:
                        logging.error(f"  Error fitting scaler for '{name}': {e}")
                else:
                     logging.warning(f"  Cannot fit scaler for '{name}', all values are NaN.")
            else:
                 logging.warning(f"  Cannot fit scaler for '{name}', column missing, empty, or non-numeric.")

        # --- Update Config Price Range ---
        # After fitting the price_scaler, update the config's price_range
        # This ensures the environment uses the actual observed min/max fares.
        if hasattr(self.price_scaler, 'data_min_') and hasattr(self.price_scaler, 'data_max_'):
             # Ensure min/max are valid numbers before updating
             min_price = float(self.price_scaler.data_min_[0])
             max_price = float(self.price_scaler.data_max_[0])
             if pd.notna(min_price) and pd.notna(max_price) and min_price < max_price:
                 self.config.price_range = (min_price, max_price)
                 logging.info(f"  Configuration price_range updated based on fitted Fare data: {self.config.price_range}")
             else:
                 logging.warning("  Could not update config price_range, fitted scaler min/max invalid.")

    def _apply_scalers(self, df: pd.DataFrame) -> pd.DataFrame:
        """Applies the fitted scalers to transform numerical columns."""
        logging.debug("Applying numerical scalers...")

        # Map original feature names to scaler instances and target scaled column names
        scalers_to_apply = {
            'Price': (self.price_scaler, 'Price_scaled'),
            'Days_left': (self.days_left_scaler, 'Days_Left_scaled'), # Renamed target
            'Duration_in_hours': (self.duration_scaler, 'Duration_scaled'), # Renamed target
            'Weather_Temp_C': (self.temp_scaler, 'Weather_Temp_scaled'), # Renamed target
            'Weather_Humidity': (self.humidity_scaler, 'Weather_Humidity_scaled'), # Renamed target
            'Age': (self.age_scaler, 'Age_scaled'),
            'Fuel Consumption (L/hour)': (self.fuel_eff_scaler, 'Fuel_Efficiency_scaled'), # Renamed target
            'Price_MA7': (self.fare_ma7_scaler, 'Price_MA7_scaled'),
            'Price_Lag1': (self.fare_lag1_scaler, 'Price_Lag1_scaled')
        }

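        # Apply each scaler
        for source_col, (scaler, target_col) in scalers_to_apply.items():
            series = df.get(source_col)
            if series is None:
                # Source column absent: _safe_scale would return an empty array, which cannot
                # be assigned to a non-empty frame, so fill the target with zeros instead
                logging.warning(f"Column '{source_col}' not found. Filling '{target_col}' with 0.0.")
                df[target_col] = 0.0
                continue
            # Use helper for safe scaling
            df[target_col] = self._safe_scale(scaler, series, source_col)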

        return df

    def _safe_scale(self, scaler, series: Optional[pd.Series], name: str) -> np.ndarray:
        """Safely scales a Series using a fitted scaler, handling NaNs and fitting status."""
        # Return default (array of zeros) if series is missing or non-numeric
        if series is None or not pd.api.types.is_numeric_dtype(series):
             return np.zeros(len(series) if series is not None else 0, dtype=float)

        # Check if the scaler instance is fitted (has learned parameters)
        is_fitted = hasattr(scaler, 'scale_') or hasattr(scaler, 'data_min_')
        if not is_fitted:
            logging.warning(f"Scaler for '{name}' is not fitted. Returning array of zeros.")
            return np.zeros(len(series), dtype=float)

        try:
            # Impute NaNs *before* scaling. Use mean for StandardScaler, min for MinMaxScaler as simple fallbacks.
            if isinstance(scaler, StandardScaler) and hasattr(scaler, 'mean_'):
                impute_value = scaler.mean_[0] if pd.notna(scaler.mean_[0]) else 0.0
            elif isinstance(scaler, MinMaxScaler) and hasattr(scaler, 'data_min_'):
                 impute_value = scaler.data_min_[0] if pd.notna(scaler.data_min_[0]) else 0.0
            else: impute_value = 0.0 # Generic fallback

            series_filled = series.fillna(impute_value)
            # Reshape for sklearn and transform
            scaled_values = scaler.transform(series_filled.values.reshape(-1, 1)).flatten()
            return scaled_values
        except Exception as e:
            logging.error(f"  Error applying scaler for '{name}': {e}")
            # Return array of zeros on error
            return np.zeros(len(series), dtype=float)

    # --- Public Price Scaling Utilities ---

    def scale_prices(self, prices: np.ndarray) -> np.ndarray:
        """
        Scales an array of prices (fares) using the fitted price_scaler (MinMaxScaler to [0, 1]).

        Args:
            prices: A NumPy array of prices.

        Returns:
            A NumPy array of scaled prices (0-1 range).
        """
        if not self.is_fitted or not hasattr(self.price_scaler, 'scale_'):
            raise RuntimeError("Price scaler has not been fitted. Call fit() first.")
        # Use the safe scaling helper method
        # Need to convert numpy array to Series temporarily for the helper
        scaled_prices_array = self._safe_scale(self.price_scaler, pd.Series(prices), "Price (Utility Scale)")
        return scaled_prices_array

    def inverse_scale_prices(self, scaled_prices_01: np.ndarray) -> np.ndarray:
        """
        Inverse transforms an array of scaled prices (in [0, 1] range) back to the original fare range.

        Args:
            scaled_prices_01: A NumPy array of prices scaled between 0 and 1.

        Returns:
            A NumPy array of prices in the original fare scale.
        """
        if not self.is_fitted or not hasattr(self.price_scaler, 'scale_'):
             raise RuntimeError("Price scaler has not been fitted. Call fit() first.")

        # Ensure input is a NumPy array and has the correct shape for inverse_transform
        scaled_prices = np.array(scaled_prices_01).reshape(-1, 1)

        # Clip scaled prices to the expected [0, 1] range before inverse transforming
        # This prevents errors if the input action (e.g., from actor network noise) slightly exceeds bounds.
        scaled_prices_clipped = np.clip(scaled_prices, 0.0, 1.0)

        try:
            # Perform the inverse transformation
            original_prices = self.price_scaler.inverse_transform(scaled_prices_clipped)
            # Return as a flattened 1D array
            return original_prices.flatten()
        except Exception as e:
             logging.error(f"Error during inverse price scaling: {e}")
             # Fallback strategy: return the minimum price from the config range
             return np.full(scaled_prices.shape[0], self.config.price_range[0])
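
The two public price utilities above amount to a min-max round trip. Here is a minimal NumPy-only sketch of that idea, assuming the price scaler is a `MinMaxScaler` with the default `[0, 1]` feature range (as the `scale_prices` docstring states). The function names and fare bounds are made up for illustration and are not part of `DataProcessor`.

```python
import numpy as np

# Hypothetical fare bounds, standing in for price_scaler.data_min_ / data_max_.
PRICE_MIN, PRICE_MAX = 1_000.0, 50_000.0

def to_unit_range(prices: np.ndarray) -> np.ndarray:
    """Min-max scale fares to [0, 1]: (x - min) / (max - min)."""
    return (np.asarray(prices, dtype=float) - PRICE_MIN) / (PRICE_MAX - PRICE_MIN)

def from_unit_range(scaled: np.ndarray) -> np.ndarray:
    """Clip to [0, 1], then invert the scaling: min + s * (max - min)."""
    clipped = np.clip(np.asarray(scaled, dtype=float), 0.0, 1.0)
    return PRICE_MIN + clipped * (PRICE_MAX - PRICE_MIN)

fares = np.array([1_000.0, 25_500.0, 50_000.0])
scaled = to_unit_range(fares)        # 0.0, 0.5, 1.0
restored = from_unit_range(scaled)   # back to the original fares

# A noisy action slightly outside [0, 1] is clipped to the nearest bound.
noisy = from_unit_range(np.array([-0.05, 1.1]))
```

Because of the clip, the round trip is only lossless for inputs that already lie inside the fitted range.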

### 2.2 Load, Fit, Transform Data
 *Execute the data processing pipeline.*

In [None]:
data_processor = DataProcessor(config)
raw_data = data_processor._load_data()

# Fit
try:
    data_processor.fit(raw_data)
    print(f"\n✅ DataProcessor Fitted. Price Range: {config.price_range}")
except Exception as e:
    logging.error(f"Data Fitting Failed: {e}", exc_info=True)
    raise SystemExit("Stop: Fitting Error") from e

# Transform
processed_data = data_processor.transform(raw_data)

if processed_data.empty:
    raise SystemExit("❌ Processed data is empty! Check processing steps.")
else:
    print(f"✅ Data Processed. Shape: {processed_data.shape}")
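
The order of the two calls matters: statistics such as medians are learned during `fit` and only reused during `transform`. Below is a small sketch of that pattern with pandas. `MedianImputer` is a hypothetical stand-in written for this illustration, not the notebook's `DataProcessor`.

```python
import numpy as np
import pandas as pd

class MedianImputer:
    """Toy fit/transform imputer: learns per-column medians, reuses them later."""

    def __init__(self, columns):
        self.columns = list(columns)
        self.medians_ = {}

    def fit(self, df: pd.DataFrame) -> "MedianImputer":
        for col in self.columns:
            median = df[col].median()  # NaNs are skipped by default
            # Fall back to 0.0 when the whole column is NaN
            self.medians_[col] = 0.0 if pd.isna(median) else float(median)
        return self

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        out = df.copy()  # leave the caller's frame untouched
        for col in self.columns:
            out[col] = out[col].fillna(self.medians_[col])
        return out

train = pd.DataFrame({"Price": [100.0, 300.0, np.nan, 200.0]})
new = pd.DataFrame({"Price": [np.nan, 900.0]})

imputer = MedianImputer(["Price"]).fit(train)
filled = imputer.transform(new)  # the NaN becomes the *training* median, 200.0
```

In the cell above the same frame is passed to both calls, so the distinction only shows up once held-out data is transformed with the already-fitted processor.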

### 2.3 Interactive Data Exploration
 *Explore distributions and data samples interactively.*


In [None]:
# Display Sample of Processed Data
display(HTML("<h4>Processed Data Sample:</h4>"))
display(processed_data.sample(5))

# Interactive Distribution Plot
numeric_cols = processed_data.select_dtypes(include=np.number).columns.tolist()
# Filter out obviously encoded columns for better default view
numeric_cols_plot = [c for c in numeric_cols if not c.endswith('_encoded') and c not in ['IsHoliday', 'IsWeekend']]

@interact
def show_distributions(column=Dropdown(options=numeric_cols_plot, value='Price')):
    plt.figure(figsize=(10, 4))
    sns.histplot(processed_data[column].dropna(), kde=True, bins=50) # dropna for safety
    plt.title(f'Distribution of {column}')
    plt.xlabel(column)
    plt.ylabel('Frequency')
    plt.show()
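
When reading the `Price_MA7` and `Price_Lag1` distributions, it helps to remember how they were built: `rolling(window=...)` and `shift(1)` applied per group count rows, not calendar days. Here is a small sketch on made-up data (the route and airline names are placeholders, and the window is shortened to 3 so the effect is visible on a few rows).

```python
import pandas as pd

toy = pd.DataFrame({
    "Route": ["A-B"] * 4 + ["C-D"] * 2,
    "Airline": ["X"] * 4 + ["Y"] * 2,
    "Date": pd.to_datetime(["2023-01-01", "2023-01-02", "2023-01-05", "2023-01-06",
                            "2023-01-01", "2023-01-02"]),
    "Price": [100.0, 200.0, 300.0, 400.0, 50.0, 70.0],
}).sort_values(["Route", "Airline", "Date"])

grouped = toy.groupby(["Route", "Airline"])["Price"]

# Rolling mean over the last 3 rows of each group.
toy["Price_MA3"] = grouped.transform(lambda s: s.rolling(window=3, min_periods=1).mean())

# Previous row's price within the group; the first row of each group has no predecessor.
toy["Price_Lag1"] = grouped.shift(1)
```

The row dated 2023-01-05 receives the price from 2023-01-02 as its lag, three days earlier, because the shift is by position. The NaN at the start of each group is what the median imputation step later fills.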

## 3. Reinforcement Learning Environment
---
*Visual environment diagram (placeholder).* Define the environment dynamics.


### 3.1 Environment Diagram (Conceptual)
*Illustrates the Agent-Environment interaction loop.* 
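
Until the diagram is filled in, the loop it describes can be sketched in code. The toy environment and random policy below are hypothetical placeholders with a made-up demand rule. They only illustrate the state, action, reward cycle and are not `AirlinePricingEnv` or TD3 itself.

```python
import numpy as np

class ToyPricingEnv:
    """One flight, fixed horizon; the action is a price level in [0, 1]."""

    def __init__(self, capacity: int = 100, horizon: int = 5):
        self.capacity, self.horizon = capacity, horizon

    def reset(self) -> np.ndarray:
        self.step_count, self.seats_sold = 0, 0
        return self._state()

    def _state(self) -> np.ndarray:
        seats_left = (self.capacity - self.seats_sold) / self.capacity
        return np.array([seats_left, self.step_count / self.horizon], dtype=np.float32)

    def step(self, action: float):
        price_level = float(np.clip(action, 0.0, 1.0))
        # Made-up demand rule: a higher price level sells fewer seats this step.
        demand = int(round(30 * (1.0 - price_level)))
        sold = min(demand, self.capacity - self.seats_sold)
        self.seats_sold += sold
        self.step_count += 1
        reward = sold * price_level  # revenue in scaled price units
        done = self.step_count >= self.horizon
        return self._state(), reward, done

rng = np.random.default_rng(0)
env = ToyPricingEnv()
state, done, total_reward = env.reset(), False, 0.0
while not done:
    action = rng.uniform(0.0, 1.0)          # stand-in for the actor network
    state, reward, done = env.step(action)  # environment responds
    total_reward += reward                  # an agent would store the transition here
```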

### 3.2 Environment Class (`AirlinePricingEnv`)
 *Defines the simulation logic.*
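
The class below sizes its state as `num_shared + n_flights * num_flight_specific`: one block of date-level features plus one block per flight. Here is a minimal sketch of such a layout, using the feature counts listed in `_determine_state_size` (7 shared, 22 per flight) and a hypothetical flight count. The ordering (shared block first, then flights in index order) is an assumption for illustration.

```python
import numpy as np

NUM_SHARED = 7            # shared features listed in _determine_state_size
NUM_FLIGHT_SPECIFIC = 22  # per-flight features listed in _determine_state_size
n_flights = 3             # hypothetical number of Route-Airline pairs

state_size = NUM_SHARED + n_flights * NUM_FLIGHT_SPECIFIC

# Placeholder blocks: zeros for shared features, the flight index for each flight block.
shared_block = np.zeros(NUM_SHARED, dtype=np.float32)
flight_blocks = [np.full(NUM_FLIGHT_SPECIFIC, i, dtype=np.float32) for i in range(n_flights)]

# Assumed ordering: shared block first, then flights in index order.
state = np.concatenate([shared_block] + flight_blocks)

def flight_slice(flight_index: int) -> slice:
    """Positions of one flight's features inside the flat state vector."""
    start = NUM_SHARED + flight_index * NUM_FLIGHT_SPECIFIC
    return slice(start, start + NUM_FLIGHT_SPECIFIC)
```

The state grows linearly with the number of flights, which is worth keeping in mind when sizing the actor and critic networks.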

In [None]:
class AirlinePricingEnv:
    """
    Simulates the airline pricing environment for Reinforcement Learning.

    The environment takes actions representing desired price levels (scaled),
    simulates demand and costs based on heuristics, calculates profit,
    and provides the next state and reward to the agent.

    Key Attributes:
        data (pd.DataFrame): Preprocessed historical flight data.
        data_processor (DataProcessor): Instance used for data transformations (esp. price scaling).
        config (Config): Configuration object with environment parameters.
        device (torch.device): CPU or GPU device for tensor operations.
        n_flights (int): Number of unique Route-Airline combinations being managed.
        state_size (int): Dimension of the state vector provided to the agent.
        action_size (int): Dimension of the action vector (equals n_flights).
        current_prices (np.ndarray): Current prices set for each flight.
        seats_sold (np.ndarray): Number of seats sold for each flight so far in the episode.
        current_date (pd.Timestamp): Current simulation day.
    """

    def __init__(self, processed_data: pd.DataFrame, data_processor: DataProcessor, config: Config, device: torch.device):
        """
        Initializes the Airline Pricing Environment.

        Args:
            processed_data: DataFrame containing preprocessed features (output from DataProcessor.transform).
            data_processor: Fitted DataProcessor instance (needed for price scaling & stats).
            config: Configuration object containing simulation parameters (capacity, length, etc.).
            device: The torch device (CPU or GPU) to use for state tensors.
        """
        if processed_data.empty:
            raise ValueError("Processed data cannot be empty for environment initialization.")

        # --- Store Core Components ---
        # Sort data for potentially easier lookups later, though data_by_date is primary
        self.data = processed_data.sort_values(by=['Route', 'Airline', 'Date']).reset_index(drop=True)
        self.data_processor = data_processor
        self.config = config
        self.device = device
        logging.info("Initializing AirlinePricingEnv...")

        # --- Environment Parameters ---
        self.seats_capacity = self.config.seats_capacity
        self.simulation_length_days = self.config.simulation_length_days
        # Price range is crucial and should be set by DataProcessor during fit
        self.price_range = self.config.price_range
        # Pre-calculate price range delta for efficient action scaling
        self.price_delta = self.price_range[1] - self.price_range[0]
        if self.price_delta <= 0:
             logging.warning(f"Invalid price range detected ({self.price_range}). Check data processor fitting.")
             # Provide a fallback range if needed
             # self.price_range = (1000.0, 50000.0)
             # self.price_delta = self.price_range[1] - self.price_range[0]
        logging.info(f"  Environment using Price Range: [{self.price_range[0]:.2f}, {self.price_range[1]:.2f}] (Delta: {self.price_delta:.2f})")

        # --- Identify Unique Flights ---
        # A "flight" in this context is a unique Route-Airline combination
        if 'Route' not in self.data.columns or 'Airline' not in self.data.columns:
            raise ValueError("Processed data must contain 'Route' and 'Airline' columns.")
        # Get unique pairs and create a mapping for easy indexing
        self.flights = self.data[['Route', 'Airline']].drop_duplicates().to_records(index=False)
        self.n_flights = len(self.flights)
        if self.n_flights == 0: raise ValueError("No unique flights (Route-Airline pairs) found in the processed data.")
        # Create a dictionary: {(Route, Airline): index}
        self.flight_map = {(route, airline): i for i, (route, airline) in enumerate(self.flights)}
        logging.info(f"  Found {self.n_flights} unique flights (Route-Airline pairs).")

        # --- Simulation Time Tracking ---
        self.start_date = self.data['Date'].min()
        # Calculate end date based on simulation length
        self.end_date = self.start_date + pd.Timedelta(days=self.simulation_length_days - 1)
        logging.info(f"  Simulation period: {self.start_date.date()} to {self.end_date.date()} ({self.simulation_length_days} days)")
        self.current_step: int = 0          # Counter for steps within an episode
        self.current_date: pd.Timestamp = self.start_date # Current simulation day

        # --- Dynamic Environment State ---
        # Track seats sold for each flight within the current episode
        self.seats_sold = np.zeros(self.n_flights, dtype=int)
        # Track the price set by the agent for each flight on the *current* day
        # Initialize prices using the global historical mean fare calculated by DataProcessor
        initial_price = self.data_processor.global_fare_mean
        # Ensure initial price is within the allowed range
        initial_price = np.clip(initial_price if pd.notna(initial_price) else self.price_range[0],
                                self.price_range[0], self.price_range[1])
        self.current_prices = np.full(self.n_flights, fill_value=initial_price, dtype=float)

        # --- Precomputed Data Structures for Efficiency ---
        # Group data by normalized date for quick lookups in the step function
        self.data_by_date: Dict[pd.Timestamp, pd.DataFrame] = {
            date.normalize(): group
            for date, group in self.data.groupby(self.data['Date'].dt.normalize())
        }
        # Store historical fare statistics (mean/median per route/airline) from DataProcessor
        self.hist_fare_stats = self.data_processor.historical_fare_stats
        # Store global fallback mean fare
        self.global_fare_mean = self.data_processor.global_fare_mean

        # --- Determine State and Action Dimensions ---
        # State size depends on the features defined in _determine_state_size
        self.state_size: int = self._determine_state_size()
        # Action size is the number of flights, as the agent sets one price per flight
        self.action_size: int = self.n_flights

        logging.info(f"  Environment State Size: {self.state_size}")
        logging.info(f"  Environment Action Size: {self.action_size} (one price per flight)")
        if self.state_size == 0:
             raise ValueError("State size calculation failed. Check required features and data processing.")

    def _determine_state_size(self) -> int:
        """
        Determines the size of the state vector based on explicitly defined features.
        Checks if required features exist in the processed data.

        Returns:
            The calculated total dimension of the state vector.
        """
        logging.debug("Determining state vector size and features...")
        # Define features EXPLICITLY included in the state vector
        # Part 1: Features shared across all flights for a given day
        shared_features = [
            'DayOfWeek', 'Month', 'WeekOfYear', 'DayOfYear', 'IsWeekend', 'IsHoliday',
            'FuelPrice' # Using the constant fallback value
        ]
        # Part 2: Features specific to each flight
        flight_features = [
            # Internal dynamic state (calculated within the environment)
            'Seats_Left_scaled',      # Proportion of seats remaining (0-1)
            'Current_Price_scaled',   # Agent's previous price, scaled (0-1) by historical range
            # Data-driven features (from processed_data)
            'Days_Left_scaled',         # Days until departure, scaled (0-1)
            'Duration_scaled',        # Flight duration, standardized
            'Price_MA7_scaled',       # 7-day rolling avg historical price, standardized
            'Price_Lag1_scaled',      # Previous day's historical price, standardized
            # Encoded categorical features
            'Route_encoded',          # Integer ID for route
            'Airline_encoded',        # Integer ID for airline
            'Class_encoded',          # Integer ID for class
            'Stops_encoded',          # Integer ID for number of stops (ordinal)
            'Source_encoded',         # Integer ID for origin city
            'Destination_encoded',    # Integer ID for destination city
            'Season_encoded',         # Integer ID for season
            'JourneyDay_encoded',     # Integer ID for day name (ordinal)
            'Departure_encoded',      # Integer ID for departure time block (ordinal)
            'Arrival_encoded',        # Integer ID for arrival time block (ordinal)
            # Weather features (merged and scaled/encoded)
            'Weather_Temp_scaled',    # Avg daily temp at origin, standardized
            'Weather_Humidity_scaled',# Avg daily humidity at origin, standardized
            'WeatherCond_encoded',    # Integer ID for weather condition text
            # Aircraft features (merged and scaled/encoded)
            'Age_scaled',             # Aircraft age, standardized (column name as written by _apply_scalers)
            'Fuel_Efficiency_scaled', # Fuel consumption rate, standardized
            'AircraftModel_encoded'   # Integer ID for aircraft model/type
        ]

        # --- Verification ---
        # Check if all required data-driven features are present in the processed DataFrame
        all_required_data_cols = [f for f in shared_features + flight_features
                                  if f not in ['Seats_Left_scaled', 'Current_Price_scaled', 'FuelPrice']] # Exclude internally generated state
        missing_cols = [col for col in all_required_data_cols if col not in self.data.columns]
        if missing_cols:
            logging.error(f"CRITICAL ERROR: State features missing from processed data columns: {missing_cols}")
            logging.error(f"Available columns are: {self.data.columns.tolist()}")
            return 0 # Indicate failure

        num_shared = len(shared_features)
        num_flight_specific = len(flight_features)
        determined_size = num_shared + self.n_flights * num_flight_specific

        # Store feature names lists for use in state construction methods
        self._shared_feature_names = shared_features
        self._flight_feature_names = flight_features
        # Separate internal vs data-driven features for clarity in _get_state_for_flight
        self._flight_internal_state_features = ['Seats_Left_scaled', 'Current_Price_scaled']
        self._flight_data_features = [f for f in flight_features if f not in self._flight_internal_state_features]
        logging.debug(f"  State size determined: {determined_size} ({num_shared} shared + {self.n_flights} flights * {num_flight_specific} flight-specific)")
        return determined_size

    def _get_shared_state(self, date: pd.Timestamp) -> np.ndarray:
        """
        Constructs the shared part of the state vector for a given date.

        Args:
            date: The current simulation timestamp.

        Returns:
            A NumPy array containing the shared state features.
        """
        norm_date = date.normalize() # Use normalized date for lookup
        day_data = self.data_by_date.get(norm_date) # Retrieve pre-grouped data for the day

        # Initialize state vector with zeros
        shared_state_values = np.zeros(len(self._shared_feature_names), dtype=np.float32)

        # Get index of FuelPrice to set the fallback value correctly
        try: fuel_idx = self._shared_feature_names.index('FuelPrice')
        except ValueError: fuel_idx = -1 # Should not happen if defined above

        # Handle cases where data for the specific date might be missing
        if day_data is None or day_data.empty:
            logging.log(logging.DEBUG - 1, f"No historical data found for date {norm_date}. Using defaults for shared state.")
            # Set the fallback fuel price if the feature is included
            if fuel_idx != -1: shared_state_values[fuel_idx] = self.data_processor.fuel_price_mean
            return shared_state_values

        # Use the first row of the day's data for shared features (they should be constant for the day)
        row = day_data.iloc[0]
        # Populate the state vector based on defined feature names
        for i, fname in enumerate(self._shared_feature_names):
            if fname == 'FuelPrice':
                shared_state_values[i] = self.data_processor.fuel_price_mean # Use constant fallback
            else:
                # Use .get(fname, 0.0) for safety, defaulting to 0.0 if a feature is unexpectedly missing
                shared_state_values[i] = row.get(fname, 0.0)

        return shared_state_values

    def _get_state_for_flight(self, flight_index: int, date: pd.Timestamp) -> np.ndarray:
        """
        Constructs the flight-specific part of the state vector for a single flight on a given date.

        Args:
            flight_index: The index of the flight (in self.flights).
            date: The current simulation timestamp.

        Returns:
            A NumPy array containing the flight-specific state features.
        """
        # Get route and airline identifier for this flight index
        route, airline = self.flights[flight_index]
        # Initialize the state vector part with zeros
        flight_state_values = np.zeros(len(self._flight_feature_names), dtype=np.float32)

        # --- 1. Get Internal (Dynamic) State Features ---
        # Seats Left: Calculate remaining seats and scale by capacity
        seats_left = self.seats_capacity - self.seats_sold[flight_index]
        seats_left_scaled = seats_left / self.seats_capacity if self.seats_capacity > 0 else 0.0
        # Current Price: Use the price set in the *previous* step, scaled 0-1
        current_price_scaled_01 = self.data_processor.scale_prices(np.array([self.current_prices[flight_index]]))[0]

        # Populate the internal state features in the vector
        try:
            # Find indices based on predefined names
            sl_idx = self._flight_feature_names.index('Seats_Left_scaled')
            cp_idx = self._flight_feature_names.index('Current_Price_scaled')
            flight_state_values[sl_idx] = seats_left_scaled
            flight_state_values[cp_idx] = current_price_scaled_01
        except ValueError as e:
            # This indicates a mismatch between defined features and the index lookup
            logging.error(f"State feature name mismatch error: {e}. Check _determine_state_size.")

        # --- 2. Get Data-Driven State Features ---
        norm_date = date.normalize() # Use normalized date for lookup
        day_data = self.data_by_date.get(norm_date) # Get data for the current day

        # If no data exists for the day, return state with zeros for data features
        if day_data is None:
            logging.log(logging.DEBUG - 1, f"No historical data for date {norm_date}, flight {flight_index}. Using zeros for data features.")
            return flight_state_values

        # Find the specific row(s) for this flight (Route-Airline) on this date
        flight_data_rows = day_data[(day_data['Route'] == route) & (day_data['Airline'] == airline)]

        # If no specific data for this flight on this day, return state with zeros
        if flight_data_rows.empty:
            logging.log(logging.DEBUG - 1, f"No specific historical data for flight {route}-{airline} on {norm_date}. Using zeros for data features.")
            return flight_state_values

        # Use the first matching row if multiple exist (shouldn't happen often with daily data)
        row = flight_data_rows.iloc[0]

        # Populate the remaining features using data from the matched row
        for fname in self._flight_data_features: # Iterate through data-driven feature names
            try:
                # Find the index in the full flight feature list
                f_idx = self._flight_feature_names.index(fname)
                # Get value from row, defaulting to 0.0 if column missing from row (shouldn't happen after processing)
                flight_state_values[f_idx] = row.get(fname, 0.0)
            except ValueError:
                # This feature name wasn't found in the main list - indicates definition error.
                # Continue, but the state will be incomplete/incorrect.
                logging.error(f"Feature '{fname}' not found in _flight_feature_names list.")

        return flight_state_values

    def _get_state_vector(self) -> torch.Tensor:
        """
        Constructs the full, flattened state vector for the current simulation date
        by combining shared and all flight-specific states. Performs validation.

        Returns:
            A PyTorch tensor representing the complete state.
        """
        date = self.current_date # Use the environment's current date

        # Get the shared part of the state
        shared_state = self._get_shared_state(date)
        # Get the specific state part for each flight
        flight_states_list = [self._get_state_for_flight(i, date) for i in range(self.n_flights)]

        # --- Concatenation ---
        # Ensure components are NumPy arrays before combining
        shared_state_np = np.asarray(shared_state, dtype=np.float32)
        # Flatten the list of flight state arrays into a single 1D array
        flight_states_np_flat = np.asarray(flight_states_list, dtype=np.float32).flatten()

        # Concatenate shared state + flattened flight states
        full_state = np.concatenate([shared_state_np, flight_states_np_flat])

        # --- Validation and Sanitization ---
        # 1. Size Check: Ensure constructed state size matches expected size
        if len(full_state) != self.state_size:
            logging.warning(f"Constructed state size mismatch! Expected {self.state_size}, got {len(full_state)}. Padding/Truncating.")
            # Create a correctly sized zero array and copy data into it
            correct_state = np.zeros(self.state_size, dtype=np.float32)
            limit = min(len(full_state), self.state_size)
            correct_state[:limit] = full_state[:limit]
            full_state = correct_state

        # 2. NaN/Infinity Check: Replace any invalid numerical values
        if np.isnan(full_state).any() or np.isinf(full_state).any():
            logging.warning(f"NaN/Inf detected in state vector for date {date}. Imputing with 0/large values.")
            full_state = np.nan_to_num(full_state, nan=0.0, posinf=1e6, neginf=-1e6) # Replace NaN with 0, Inf with large numbers

        # Convert the final NumPy array to a PyTorch tensor on the specified device
        state_tensor = torch.tensor(full_state, dtype=torch.float32, device=self.device)
        return state_tensor

    def reset(self) -> torch.Tensor:
        """
        Resets the environment to its initial state for the start of a new episode.

        Returns:
            The initial state vector as a PyTorch tensor.
        """
        logging.debug("Resetting environment state for new episode...")
        # Reset step counter and simulation date
        self.current_step = 0
        self.current_date = self.start_date # Reset to the beginning of the simulation period

        # Reset dynamic state arrays
        self.seats_sold.fill(0) # Reset seats sold for all flights
        # Reset current prices to the initial calculated value (e.g., global mean fare)
        initial_price = self.global_fare_mean
        initial_price = np.clip(initial_price if pd.notna(initial_price) else self.price_range[0],
                                self.price_range[0], self.price_range[1])
        self.current_prices.fill(initial_price)

        logging.debug(f"Environment reset to date {self.current_date.date()}.")
        # Return the state vector corresponding to the reset state
        return self._get_state_vector()

    # --------------------------------------------------------------------------
    # Demand and Cost Heuristics (Internal Simulation Logic)
    # --------------------------------------------------------------------------

    def _calculate_demand(self, flight_index: int, price: float) -> int:
        """
        (HEURISTIC MODEL) Calculates estimated demand for a specific flight at a given price.
        This function uses simplified rules and does not rely on real historical demand data.

        Args:
            flight_index: Index of the flight.
            price: The price set for the flight for the current day.

        Returns:
            An estimated integer demand value.
        """
        route, airline = self.flights[flight_index]
        flight_key = (route, airline) # Key for looking up historical stats
        date = self.current_date
        norm_date = date.normalize()

        # --- 1. Base Potential ---
        # Arbitrary starting point for demand potential
        base_potential = 50.0 # Example: Start with a base of 50 interested customers

        # --- 2. Get Contextual Data ---
        # Retrieve historical data row for this flight on this day, if available
        day_data = self.data_by_date.get(norm_date)
        flight_row = None
        if day_data is not None:
             rows = day_data[(day_data['Route'] == route) & (day_data['Airline'] == airline)]
             if not rows.empty: flight_row = rows.iloc[0]

        # --- 3. Apply Demand Factors (Multipliers) ---
        # Price Effect: Demand decreases relative to historical average price
        hist_stats = self.hist_fare_stats.get(flight_key) # Get precomputed mean/median
        # Use mean if available, else fallback to global mean
        hist_avg = hist_stats['mean'] if hist_stats and pd.notna(hist_stats.get('mean')) else self.global_fare_mean
        hist_avg = self.global_fare_mean if hist_avg <= 0 else hist_avg # Ensure positive reference price
        price_ratio = price / hist_avg # How does current price compare to average?
        price_sensitivity = 0.7 # Parameter controlling price effect steepness
        # Hyperbolic decay: demand drops as price exceeds the historical average;
        # prices at or below the average leave demand unchanged (effect = 1.0)
        price_effect = 1.0 / (1.0 + price_sensitivity * max(0.0, price_ratio - 1.0))

        # Days Left Effect: Demand increases closer to departure (urgency)
        # Get days_left from the data row or estimate from current step
        if flight_row is not None and 'Days_left' in flight_row:
             days_left = flight_row['Days_left']
        else: days_left = max(1, self.simulation_length_days - self.current_step) # Estimate if missing
        days_left_urgency = 1.2 # Maximum multiplier (e.g., 20% higher demand on day 0)
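        # Non-linear effect: demand increases faster closer to the end.
        # Clamp the elapsed fraction to [0, 1] so a Days_left value beyond the simulation horizon cannot raise demand.
        elapsed_fraction = min(1.0, max(0.0, 1.0 - (days_left / self.simulation_length_days)))
        days_left_effect = 1.0 + (days_left_urgency - 1.0) * elapsed_fraction**2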

        # Class/Stops Effect: Simple adjustments based on flight characteristics
        class_effect = 1.0
        stops_effect = 1.0
        if flight_row is not None:
             # Assume Business class has lower volume demand (adjust multiplier as needed)
             if flight_row.get('Class') == 'Business': class_effect = 0.5
             # Assume demand decreases with more stops
             stops = flight_row.get('Total_stops', 'non-stop')
             if stops == '1-stop': stops_effect = 0.7
             elif stops == '2-stops': stops_effect = 0.4
             # Add more rules for '3-stops' etc. if needed

        # --- 4. Combine Factors ---
        estimated_demand = base_potential * price_effect * days_left_effect * class_effect * stops_effect

        # --- 5. Add Stochasticity ---
        # Introduce random noise to make demand less predictable
        noise_factor = np.random.normal(loc=1.0, scale=0.25) # Multiplicative noise: mean 1.0, standard deviation 0.25
        final_demand = estimated_demand * noise_factor

        # --- 6. Constraints ---
        # Ensure demand is non-negative and an integer
        final_demand = max(0, round(final_demand))

        # Optional: Log the calculation details for debugging
        logging.log(logging.DEBUG - 1, # Use lower debug level
                    f"Demand Calc ({route}-{airline}, {date.date()}, P:{price:.0f}): "
                    f"Base={base_potential:.1f}, P_Eff={price_effect:.2f}, Days_Eff={days_left_effect:.2f}, "
                    f"Cls_Eff={class_effect:.1f}, Stp_Eff={stops_effect:.1f}, Noise={noise_factor:.2f} -> Demand={final_demand}")

        return int(final_demand)


    def _calculate_operational_cost(self, flight_index: int) -> float:
        """
        (HEURISTIC MODEL) Calculates the approximate operational cost *per seat* for a flight.
        Uses merged aircraft specs and fallback values.

        Args:
            flight_index: Index of the flight.

        Returns:
            Estimated cost per seat for the flight.
        """
        route, airline = self.flights[flight_index]
        date = self.current_date
        norm_date = date.normalize()

        # --- Get Flight Context (Aircraft Type, Duration) ---
        day_data = self.data_by_date.get(norm_date)
        aircraft_type = 'Unknown' # Default if not found
        duration_hours = 3.0      # Default duration if not found
        if day_data is not None:
             rows = day_data[(day_data['Route'] == route) & (day_data['Airline'] == airline)]
             if not rows.empty:
                 row = rows.iloc[0]
                 # Get AircraftType if available (depends on successful merge in DataProcessor)
                 aircraft_type = row.get('AircraftType', 'Unknown')
                 # Get duration, ensuring it's positive
                 duration_hours = row.get('Duration_in_hours', 3.0)
                 duration_hours = max(0.1, duration_hours) # Prevent zero or negative duration

        # --- Cost Components ---
        # Fuel Cost = Fuel Price * Fuel Rate * Duration
        fuel_price = self.data_processor.fuel_price_mean # Using fallback value
        # Get fuel rate from stored dictionary, use default if model unknown
        fuel_rate_l_per_hr = self.data_processor.aircraft_fuel_rates.get(aircraft_type, 3000) # Example default
        total_fuel_cost = fuel_price * fuel_rate_l_per_hr * duration_hours

        # Maintenance Cost = Hourly Rate * Duration
        maint_rate_usd_per_hr = self.data_processor.maintenance_costs.get(aircraft_type, 1000) # Example default
        total_maint_cost = maint_rate_usd_per_hr * duration_hours

        # Other Fixed Costs (Placeholder for crew, landing fees, etc.)
        other_fixed_costs = 5000 # Example fixed cost per flight

        # --- Total Cost Calculation ---
        total_flight_cost = total_fuel_cost + total_maint_cost + other_fixed_costs

        # --- Cost Per Seat ---
        # Divide total cost by capacity (handle potential division by zero)
        cost_per_seat = total_flight_cost / self.seats_capacity if self.seats_capacity > 0 else total_flight_cost
        # Ensure cost is non-negative
        cost_per_seat = max(0.0, cost_per_seat)

        # --- Cost Capping (Sanity Check) ---
        # Prevent unrealistic scenarios where cost > high price. Cap at e.g., 150% of min possible price.
        max_reasonable_cost = self.price_range[0] * 1.5
        cost_per_seat = min(cost_per_seat, max_reasonable_cost)

        logging.log(logging.DEBUG - 1, # Use lower debug level
                    f"Cost Calc ({route}-{airline}, AC:{aircraft_type}, Dur:{duration_hours:.1f}h): "
                    f"Fuel({fuel_rate_l_per_hr:.0f}L/h*${fuel_price:.1f}/L)={(fuel_price * fuel_rate_l_per_hr * duration_hours):.0f}, "
                    f"Maint(${maint_rate_usd_per_hr:.0f}/h)={(maint_rate_usd_per_hr * duration_hours):.0f}, "
                    f"Fixed={other_fixed_costs} -> Total={total_flight_cost:.0f}, PerSeat={cost_per_seat:.2f}")

        return cost_per_seat

    # --------------------------------------------------------------------------
    # Main Environment Interaction Method
    # --------------------------------------------------------------------------

    def step(self, actions_neg1_to_1: np.ndarray) -> Tuple[torch.Tensor, float, bool, Dict]:
        """
        Executes one time step in the environment based on the agent's actions.
        Accepts actions scaled between -1 and 1 (from Actor's Tanh output).

        Args:
            actions_neg1_to_1: A NumPy array of actions (scaled prices), one for each flight,
                                in the range [-1, 1].

        Returns:
            A tuple containing:
            - next_state (torch.Tensor): The state vector for the next day.
            - reward (float): The total profit accumulated across all flights for the current day.
            - done (bool): Boolean indicating if the simulation episode has ended.
            - info (Dict): Dictionary containing auxiliary information (e.g., profit/demand per flight).
        """
        # --- Action Validation ---
        if len(actions_neg1_to_1) != self.n_flights:
             logging.error(f"Action size mismatch in step: Expected {self.n_flights}, got {len(actions_neg1_to_1)}.")
             # Fallback: Use default action (0 -> middle price) for missing/extra actions
             corrected_actions = np.zeros(self.n_flights)
             limit = min(len(actions_neg1_to_1), self.n_flights)
             corrected_actions[:limit] = actions_neg1_to_1[:limit]
             actions_neg1_to_1 = corrected_actions

        # --- State Initialization for the Step ---
        total_profit_today = 0.0
        # Info dictionary to store step details (useful for logging/analysis)
        step_info = {'profits': {}, 'demands': {}, 'seats_sold': {}, 'prices': {}, 'costs': {}}

        # --- Scale Actions to Prices ---
        # Convert actions from [-1, 1] range to the actual price range [min_price, max_price]
        # Formula: price = min_price + (action + 1) * 0.5 * (max_price - min_price)
        prices_today = self.price_range[0] + (actions_neg1_to_1 + 1.0) * 0.5 * self.price_delta

        # Clip prices *after* scaling to ensure they strictly stay within the defined bounds
        prices_today = np.clip(prices_today, self.price_range[0], self.price_range[1])

        # --- Simulate Each Flight for the Day ---
        for i in range(self.n_flights):
            route, airline = self.flights[i]
            flight_key = (route, airline) # Unique identifier for dictionaries
            price = prices_today[i]

            # Store the actual price set for this flight (for state calculation next step)
            self.current_prices[i] = price

            # --- Simulate Market Response ---
            # Calculate demand based on the set price using the heuristic model
            demand = self._calculate_demand(flight_index=i, price=price)
            # Calculate the operational cost per seat for this flight
            cost_per_seat = self._calculate_operational_cost(flight_index=i)

            # --- Calculate Sales ---
            # Determine available seats for this flight
            available_seats = self.seats_capacity - self.seats_sold[i]
            # Seats sold is the minimum of demand and available seats (cannot sell more than available)
            # Also ensure non-negative values.
            sold_today = min(max(0, demand), max(0, available_seats))

            # --- Calculate Profit ---
            # Profit for this flight = (Revenue per seat - Cost per seat) * Seats sold
            profit = (price - cost_per_seat) * sold_today

            # --- Update Environment State ---
            # Increment seats sold for this flight
            self.seats_sold[i] += sold_today
            # Add this flight's profit to the daily total
            total_profit_today += profit

            # --- Store Information ---
            # Record details for this flight in the info dictionary
            step_info['profits'][flight_key] = profit
            step_info['demands'][flight_key] = demand
            step_info['seats_sold'][flight_key] = sold_today
            step_info['prices'][flight_key] = price
            step_info['costs'][flight_key] = cost_per_seat

        # --- Advance Simulation Time ---
        self.current_step += 1
        self.current_date += pd.Timedelta(days=1)
        # Check if the episode termination condition is met
        done = self.current_date > self.end_date

        # --- Get Next State ---
        # Calculate the state vector for the *next* timestep
        next_state_vector = self._get_state_vector()

        if done:
            logging.debug(f"Episode finished at step {self.current_step}.")

        # --- Define Reward ---
        # The reward for the agent is the total profit achieved across all flights on this day
        reward = total_profit_today

        return next_state_vector, reward, done, step_info

### 3.3 Initialize Environment
 *Create an instance of the environment.*
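Before running the environment, it can help to sanity-check the heuristic market response it simulates. The sketch below restates the price effect from `_calculate_demand` and the sales cap and profit rule from `step` as standalone functions. The function names and the example numbers are illustrative assumptions, not part of the environment class.

```python
def price_effect(price: float, hist_avg: float, price_sensitivity: float = 0.7) -> float:
    """Demand multiplier: 1.0 at or below the historical average, decaying above it."""
    price_ratio = price / hist_avg
    return 1.0 / (1.0 + price_sensitivity * max(0.0, price_ratio - 1.0))


def daily_profit(price: float, cost_per_seat: float, demand: int, available_seats: int):
    """Sales are capped by remaining seats; profit is the per-seat margin times seats sold."""
    sold = min(max(0, demand), max(0, available_seats))
    return (price - cost_per_seat) * sold, sold


# Hypothetical numbers, chosen only to exercise the formulas
at_average = price_effect(price=5000.0, hist_avg=5000.0)       # no penalty at the average
double_average = price_effect(price=10000.0, hist_avg=5000.0)  # 1 / (1 + 0.7)
profit, sold = daily_profit(price=6000.0, cost_per_seat=2000.0, demand=50, available_seats=30)
```

Because the multiplier never exceeds 1.0, pricing below the historical average does not raise simulated demand in this heuristic; only the noise term and the days-left effect can.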

In [None]:
env = None
if 'processed_data' in locals() and not processed_data.empty:
    try:
        env = AirlinePricingEnv(processed_data, data_processor, config, device)
        print(f"\n✅ Environment initialized successfully. State size: {env.state_size}, Action size: {env.action_size}")
    except Exception as e:
        print(f"❌ Error initializing environment: {e}")
        logging.error("Env init failed.", exc_info=True)
else:
    print("❌ Env init failed: Processed data missing.")

## 4. TD3 Agent Definition
 ---
 *Define the Actor, Critic networks, Replay Buffer, and the main TD3 Agent logic.*
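A replay buffer stores past transitions and serves random mini-batches for training. A minimal sketch of that idea, using only the standard library, follows. The class name `SketchReplayBuffer`, its methods, and the transition layout are hypothetical and do not describe this notebook's own buffer.

```python
import random
from collections import deque
from typing import Deque, List, Tuple

# Hypothetical transition layout: (state, action, reward, next_state, done)
Transition = Tuple[list, list, float, list, bool]


class SketchReplayBuffer:
    """Illustrative fixed-capacity buffer with uniform random sampling."""

    def __init__(self, capacity: int, seed: int = 42) -> None:
        # deque(maxlen=...) silently evicts the oldest transition once full
        self._storage: Deque[Transition] = deque(maxlen=capacity)
        self._rng = random.Random(seed)

    def add(self, transition: Transition) -> None:
        self._storage.append(transition)

    def sample(self, batch_size: int) -> List[Transition]:
        # Uniform sampling without replacement from the stored transitions
        return self._rng.sample(list(self._storage), batch_size)

    def __len__(self) -> int:
        return len(self._storage)


buffer = SketchReplayBuffer(capacity=3)
for step in range(5):
    buffer.add(([float(step)], [0.0], 1.0, [float(step + 1)], False))
batch = buffer.sample(2)
```

With a capacity of 3 and five insertions, only the three most recent transitions remain available for sampling.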

### 4.1 Actor Network (`Actor`)
**Maps** a given state observation to a deterministic action (scaled prices in `[-1, 1]`). This represents the agent's current policy.
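To make the `[-1, 1]` convention concrete, the sketch below applies the same linear mapping that the environment's `step` method uses to turn an action into a price. The function name and the price bounds are illustrative assumptions; the real bounds come from the environment's `price_range`.

```python
def action_to_price(action: float, min_price: float, max_price: float) -> float:
    """Map an action in [-1, 1] linearly onto [min_price, max_price]."""
    price = min_price + (action + 1.0) * 0.5 * (max_price - min_price)
    # Clip after scaling, so an out-of-range action still yields a valid price
    return max(min_price, min(max_price, price))


MIN_PRICE, MAX_PRICE = 2000.0, 10000.0  # hypothetical bounds for illustration
prices = [action_to_price(a, MIN_PRICE, MAX_PRICE) for a in (-1.0, 0.0, 1.0)]
```

An action of `0.0` lands on the midpoint of the range, which is why the environment treats zero as its default action.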

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F # Often used for activations
import logging
from typing import List

class Actor(nn.Module):
    """
    Actor Network for TD3.

    Takes a state representation as input and outputs a deterministic action
    (or vector of actions) bounded within a specified range (typically [-1, 1]
    using tanh activation). This network learns the policy function π(s).
    """
    def __init__(self, state_dim: int, action_dim: int, hidden_dims: List[int], max_action: float):
        """
        Initializes the Actor network layers.

        Args:
            state_dim (int): Dimensionality of the input state space.
            action_dim (int): Dimensionality of the output action space (number of continuous actions).
                               In this case, action_dim = n_flights.
            hidden_dims (List[int]): A list defining the number of neurons in each hidden layer.
                                     Example: [256, 128] creates two hidden layers.
            max_action (float): The maximum absolute value of the action output. Used to scale
                                the output of the final tanh activation layer. Should typically be 1.0
                                if the environment expects actions in [-1, 1].
        """
        super(Actor, self).__init__() # Initialize the parent nn.Module class

        # Store max_action for scaling the output
        self.max_action = max_action
        logging.debug(f"Initializing Actor: StateDim={state_dim}, ActionDim={action_dim}, Hidden={hidden_dims}, MaxAction={max_action}")

        # --- Build Network Layers ---
        layers = []
        # Define the input dimension for the first layer
        input_dim = state_dim

        # Dynamically create hidden layers based on the hidden_dims list
        for hidden_dim in hidden_dims:
            # Linear layer: maps from input_dim to hidden_dim
            layers.append(nn.Linear(input_dim, hidden_dim))
            # Activation function (ReLU is common)
            layers.append(nn.ReLU())
            # Update input_dim for the next layer
            input_dim = hidden_dim
            logging.debug(f"  Added Actor hidden layer: Linear({layers[-2].in_features}, {layers[-2].out_features}), ReLU")


        # --- Output Layer ---
        # Final linear layer maps the last hidden layer's output to the action dimension
        layers.append(nn.Linear(input_dim, action_dim))
        # Tanh activation function: squashes the output to the range [-1, 1]
        # This is standard practice for continuous action spaces in algorithms like TD3/DDPG.
        layers.append(nn.Tanh())
        logging.debug(f"  Added Actor output layer: Linear({layers[-2].in_features}, {layers[-2].out_features}), Tanh")


        # Create the sequential model from the defined layers
        self.network = nn.Sequential(*layers)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Weight initialization applied to every submodule via self.apply()."""
        if isinstance(module, nn.Linear):
            # Example using Xavier uniform initialization
            torch.nn.init.xavier_uniform_(module.weight, gain=nn.init.calculate_gain('relu'))
            if module.bias is not None:
                module.bias.data.fill_(0.01) # Small non-zero bias


    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass of the Actor network.

        Args:
            state (torch.Tensor): A batch of state observations. Shape: (batch_size, state_dim)

        Returns:
            torch.Tensor: A batch of actions, scaled to the range [-max_action, max_action].
                          Shape: (batch_size, action_dim)
        """
        # Pass the state through the sequential network
        # Output will be in the range [-1, 1] due to the Tanh activation
        tanh_output = self.network(state)

        # Scale the output from [-1, 1] to [-max_action, max_action]
        # This allows flexibility if the environment requires actions outside [-1, 1],
        # although typically max_action is set to 1.0.
        scaled_action = self.max_action * tanh_output

        return scaled_action

### 4.2 Critic Network (`Critic`)
**Estimates** the action-value function `Q(s, a)` for a given state `s` and action `a`. TD3 uses **two** Critic networks (hence "Twin") to mitigate Q-value overestimation. Both Critics have the same architecture but are trained independently.
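The overestimation fix comes from how the two estimates are combined when the learning target is built. A minimal scalar sketch follows, assuming the standard TD3 target form; the function name and the example values are illustrative, and `gamma=0.99` matches the default in `Config`.

```python
def clipped_double_q_target(reward: float, done: bool,
                            q1_next: float, q2_next: float,
                            gamma: float = 0.99) -> float:
    """TD target using the smaller of the two target-critic estimates."""
    not_done = 0.0 if done else 1.0  # no bootstrapping past a terminal step
    return reward + gamma * not_done * min(q1_next, q2_next)


# Hypothetical values: the critics disagree, so the lower estimate (3.0) is used
target = clipped_double_q_target(reward=10.0, done=False, q1_next=5.0, q2_next=3.0)
terminal_target = clipped_double_q_target(reward=10.0, done=True, q1_next=5.0, q2_next=3.0)
```

Taking the minimum biases the target downward whenever the critics disagree, which counteracts the upward bias that a single maximizing critic tends to accumulate.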

In [None]:
import math
from typing import List, Tuple

class Critic(nn.Module):
    """
    Twin Critic Network for TD3.

    Contains two separate Q-network estimators (Q1 and Q2). Each network takes
    both the state and the action as input and outputs a single scalar value
    representing the estimated Q-value Q(s, a). Using the minimum of the two
    Q-values during target calculation helps to reduce overestimation bias.
    """
    def __init__(self, state_dim: int, action_dim: int, hidden_dims: List[int]):
        """
        Initializes the two Critic networks (Q1 and Q2).

        Args:
            state_dim (int): Dimensionality of the input state space.
            action_dim (int): Dimensionality of the input action space.
            hidden_dims (List[int]): A list defining the number of neurons in each hidden layer
                                     for *both* Q1 and Q2 networks. Example: [256, 128].
        """
        super(Critic, self).__init__() # Initialize the parent nn.Module class
        logging.debug(f"Initializing Critic (Twin): StateDim={state_dim}, ActionDim={action_dim}, Hidden={hidden_dims}")

        # --- Network 1 (Q1) ---
        layers1 = []
        # The input dimension for the first layer is the sum of state and action dimensions
        input_dim1 = state_dim + action_dim

        # Build hidden layers for Q1
        for hidden_dim in hidden_dims:
            layers1.append(nn.Linear(input_dim1, hidden_dim))
            layers1.append(nn.ReLU()) # Standard activation
            input_dim1 = hidden_dim # Update input dim for the next layer
            logging.debug(f"  Added Critic Q1 hidden layer: Linear({layers1[-2].in_features}, {layers1[-2].out_features}), ReLU")

        # Output layer for Q1: Outputs a single scalar Q-value
        layers1.append(nn.Linear(input_dim1, 1))
        logging.debug(f"  Added Critic Q1 output layer: Linear({layers1[-1].in_features}, 1)")
        # Create the sequential model for Q1
        self.q1_network = nn.Sequential(*layers1)


        # --- Network 2 (Q2) ---
        # Build Q2 with the *same architecture* but separate weights
        layers2 = []
        input_dim2 = state_dim + action_dim # Reset input dimension

        # Build hidden layers for Q2
        for hidden_dim in hidden_dims:
            layers2.append(nn.Linear(input_dim2, hidden_dim))
            layers2.append(nn.ReLU())
            input_dim2 = hidden_dim
            logging.debug(f"  Added Critic Q2 hidden layer: Linear({layers2[-2].in_features}, {layers2[-2].out_features}), ReLU")

        # Output layer for Q2: Outputs a single scalar Q-value
        layers2.append(nn.Linear(input_dim2, 1))
        logging.debug(f"  Added Critic Q2 output layer: Linear({layers2[-1].in_features}, 1)")
        # Create the sequential model for Q2
        self.q2_network = nn.Sequential(*layers2)

        # Apply shared or separate weight initialization
        self.apply(self._init_weights) # Apply initialization to both networks


    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # Kaiming uniform initialization; a=math.sqrt(5) matches nn.Linear's own default reset, adjust if needed
            torch.nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                # Initialize bias uniformly within +/- 1/sqrt(fan_in)
                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in)
                torch.nn.init.uniform_(module.bias, -bound, bound)


    def forward(self, state: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Defines the forward pass for both Critic networks.

        Args:
            state (torch.Tensor): A batch of state observations. Shape: (batch_size, state_dim)
            action (torch.Tensor): A batch of actions corresponding to the states.
                                   Shape: (batch_size, action_dim)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing the Q-value estimates
                                               from both critics: (q1, q2).
                                               Each tensor has shape: (batch_size, 1)
        """
        # Concatenate state and action tensors along the feature dimension (dim=1)
        # This creates the combined input [s, a] for the Q-networks.
        state_action_input = torch.cat([state, action], dim=1)

        # Pass the combined input through each Q-network independently
        q1 = self.q1_network(state_action_input)
        q2 = self.q2_network(state_action_input)

        return q1, q2

    def Q1(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """
        Helper method to get the Q-value estimate only from the first Critic (Q1).
        Used in the Actor loss calculation in TD3.

        Args:
            state (torch.Tensor): A batch of state observations.
            action (torch.Tensor): A batch of actions.

        Returns:
            torch.Tensor: The Q-value estimates from Critic 1. Shape: (batch_size, 1)
        """
        # Concatenate state and action
        state_action_input = torch.cat([state, action], dim=1)
        # Pass through only the Q1 network
        q1 = self.q1_network(state_action_input)
        return q1

### 4.3 TD3 Agent (`TD3Agent`)
**Combines** the Actor, Critics, target networks, replay buffer, and implements the core TD3 learning algorithm logic, including delayed policy updates, target policy smoothing, and clipped double-Q learning.
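The three mechanisms named above can be sketched on plain Python scalars and lists. This is an illustration under stated assumptions, not the agent's code: the function names are hypothetical, `policy_noise=0.2` and `tau=0.005` match the defaults in `Config`, and `noise_clip=0.5` and `policy_freq=2` are assumed values.

```python
import random
from typing import List


def smooth_target_action(target_action: float, policy_noise: float,
                         noise_clip: float, max_action: float,
                         rng: random.Random) -> float:
    """Target policy smoothing: add clipped Gaussian noise, then clip to the action bounds."""
    noise = rng.gauss(0.0, policy_noise)
    noise = max(-noise_clip, min(noise_clip, noise))
    return max(-max_action, min(max_action, target_action + noise))


def soft_update(target_params: List[float], online_params: List[float], tau: float) -> List[float]:
    """Polyak averaging: move each target parameter a fraction tau toward the online one."""
    return [tau * w + (1.0 - tau) * w_t for w, w_t in zip(online_params, target_params)]


def should_update_actor(train_step: int, policy_freq: int) -> bool:
    """Delayed policy update: the actor trains only every policy_freq critic updates."""
    return train_step % policy_freq == 0


rng = random.Random(0)
noisy_action = smooth_target_action(0.9, policy_noise=0.2, noise_clip=0.5, max_action=1.0, rng=rng)
new_target = soft_update(target_params=[0.0], online_params=[1.0], tau=0.005)
actor_steps = [t for t in range(1, 7) if should_update_actor(t, policy_freq=2)]
```

The small `tau` means target networks trail the online networks slowly, which keeps the bootstrapped targets stable while the critics are still changing.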

In [None]:
class TD3Agent:
    """
    Twin Delayed Deep Deterministic Policy Gradient (TD3) Agent.

    Implements the TD3 algorithm, which builds upon DDPG by introducing:
    1. Clipped Double Q-Learning: Uses two Critic networks and takes the minimum
       target Q-value to reduce overestimation bias.
    2. Delayed Policy Updates: Updates the Actor less frequently than the Critics
       to allow Q-value estimates to stabilize.
    3. Target Policy Smoothing: Adds noise to the target Actor's actions during
       target Q-value calculation to smooth the value landscape.
    """
    def __init__(self, state_dim: int, action_dim: int, max_action: float, config: Config, device: torch.device):
        """
        Initializes the TD3 Agent, including networks, optimizers, and replay buffer.

        Args:
            state_dim (int): Dimensionality of the state space.
            action_dim (int): Dimensionality of the action space.
            max_action (float): The maximum absolute value for actions output by the Actor's tanh.
            config (Config): Configuration object containing hyperparameters.
            device (torch.device): The device (CPU or GPU) for tensor operations.
        """
        if state_dim <= 0: raise ValueError(f"Invalid state_dim: {state_dim}")
        self.device = device
        self.action_dim = action_dim
        # Store max_action, used for clipping noise and action selection bounds
        self.max_action = max_action
        # Store relevant hyperparameters from config
        self.gamma = config.gamma          # Discount factor
        self.tau = config.tau            # Soft update factor for target networks
        self.policy_noise = config.policy_noise # Std dev for target policy smoothing noise
        self.noise_clip = config.noise_clip      # Clipping range for target policy noise
        self.policy_freq = config.policy_freq    # Frequency of delayed policy updates
        self.batch_size = config.batch_size      # Training batch size

        logging.info("Initializing TD3 Agent components...")

        # --- Initialize Networks ---
        # Actor Network (outputs actions)
        self.actor = Actor(state_dim, action_dim, config.actor_hidden_dims, max_action).to(device)
        # Target Actor Network (a slow-moving average of the main actor)
        self.actor_target = copy.deepcopy(self.actor) # Initialize target same as main actor
        # Optimizer for the Actor network
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=config.actor_lr)
        logging.info("  Actor networks (main & target) and Adam optimizer created.")

        # Critic Networks (Twin Critics Q1, Q2 estimate Q(s,a))
        self.critic = Critic(state_dim, action_dim, config.critic_hidden_dims).to(device)
        # Target Critic Network (slow-moving average)
        self.critic_target = copy.deepcopy(self.critic) # Initialize target same as main critic
        # Optimizer for *both* Critic networks (parameters are optimized together)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=config.critic_lr)
        logging.info("  Critic networks (main & target) and Adam optimizer created.")

        # --- Initialize Replay Buffer ---
        self.replay_buffer = ReplayBuffer(state_dim, action_dim, config.buffer_size, device)
        logging.info("  Replay buffer created.")

        # --- Training Counter ---
        # Keep track of total training iterations for delayed policy update schedule
        self.total_it: int = 0

        logging.info("TD3 Agent initialized successfully.")


    def select_action(self, state: np.ndarray, exploration_noise: float = 0.0) -> np.ndarray:
        """
        Selects an action based on the current state using the Actor network.
        Adds Gaussian noise for exploration during training.

        Args:
            state (np.ndarray): The current state observation.
            exploration_noise (float): Standard deviation of Gaussian noise to add
                                       for exploration. Set to 0 for deterministic evaluation.

        Returns:
            np.ndarray: The selected action, clipped to the range [-max_action, max_action].
        """
        # 1. Prepare state tensor
        # Convert NumPy state to PyTorch tensor on the correct device
        # Reshape to (1, state_dim) because the network expects a batch dimension
        state_tensor = torch.tensor(state.reshape(1, -1), dtype=torch.float32, device=self.device)

        # 2. Get action from Actor network
        self.actor.eval() # Set actor to evaluation mode (disables dropout if used)
        with torch.no_grad(): # Disable gradient calculation for inference
            action = self.actor(state_tensor).cpu().numpy().flatten() # Get action, move to CPU, flatten
        self.actor.train() # Set actor back to training mode

        # 3. Add exploration noise (if specified)
        if exploration_noise > 0:
            # Sample noise from a Gaussian distribution
            noise = np.random.normal(0, self.max_action * exploration_noise, size=self.action_dim)
            # Add noise to the deterministic action
            action = action + noise
            logging.log(logging.DEBUG - 1, f"Action with noise {exploration_noise:.2f}: {action}")


        # 4. Clip action to valid range
        # Ensure the final action (with or without noise) stays within [-max_action, max_action]
        clipped_action = np.clip(action, -self.max_action, self.max_action)

        return clipped_action


    def train(self) -> Tuple[Optional[float], Optional[float]]:
        """
        Performs a single TD3 training update step.
        Samples a batch from the replay buffer, calculates losses, and updates
        Actor and Critic networks.

        Returns:
            Tuple[Optional[float], Optional[float]]: A tuple containing the critic loss
            and actor loss for this step (actor loss might be None if it wasn't updated
            due to policy delay). Returns (None, None) if buffer size is insufficient.
        """
        # Increment the total iteration counter
        self.total_it += 1

        # 1. Check if buffer has enough samples for a batch
        if self.replay_buffer.size < self.batch_size:
             logging.debug(f"Skipping training step {self.total_it}: Buffer size ({self.replay_buffer.size}) < Batch size ({self.batch_size})")
             return None, None # Not enough samples to form a batch

        # 2. Sample a mini-batch from the replay buffer
        batch = self.replay_buffer.sample(self.batch_size)
        state = batch['state']
        action = batch['action'] # Action actually taken, stored in buffer
        reward = batch['reward']
        next_state = batch['next_state']
        done = batch['done'] # Done flags (as floats 0.0 or 1.0)

        # --- Critic Loss Calculation and Update ---
        with torch.no_grad(): # Operations inside this block don't track gradients
            # 3. Select next action using the *target* Actor network: a' = pi_target(s')
            next_action = self.actor_target(next_state)

            # 4. Apply Target Policy Smoothing: Add clipped noise to the target action
            # Sample noise from Gaussian distribution, clamp it, add to action
            policy_noise_tensor = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            # Add noise and clip the resulting action to the valid range [-max_action, max_action]
            next_action = (next_action + policy_noise_tensor).clamp(-self.max_action, self.max_action)

            # 5. Compute the target Q-value using Clipped Double-Q Learning:
            # Get Q-values for the noisy next action from *both* target Critic networks
            target_q1, target_q2 = self.critic_target(next_state, next_action)
            # Take the minimum of the two target Q-values to mitigate overestimation
            target_q_min = torch.min(target_q1, target_q2)
            # Calculate the final TD target: y = r + gamma * min(Q1_target, Q2_target) * (1 - done)
            # (1 - done) ensures the target is just 'r' if the state was terminal
            td_target = reward + (1.0 - done) * self.gamma * target_q_min

        # 6. Get current Q estimates from the main Critic networks for the batch's (s, a) pairs
        current_q1, current_q2 = self.critic(state, action)

        # 7. Compute the Critic loss: Mean Squared Error between current Q estimates and the TD target
        # Sum the MSE losses for both critics
        critic_loss = F.mse_loss(current_q1, td_target) + F.mse_loss(current_q2, td_target)
        critic_loss_val = critic_loss.item() # Store scalar value for logging
        logging.log(logging.DEBUG - 1, f"Iter {self.total_it}: Critic Loss={critic_loss_val:.4f}")


        # 8. Optimize the Critic networks
        self.critic_optimizer.zero_grad() # Reset gradients
        critic_loss.backward() # Compute gradients
        # Optional: Clip critic gradients if needed
        # nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=1.0)
        self.critic_optimizer.step() # Update critic weights

        # --- Delayed Actor Loss Calculation and Update ---
        actor_loss_val = None # Initialize actor loss as None for this step
        # Update the Actor and target networks only every 'policy_freq' iterations
        if self.total_it % self.policy_freq == 0:

            # 9. Compute Actor loss (Policy Gradient part)
            # The Actor aims to output actions that maximize the Q-value estimated by Critic 1.
            # Loss is the negative mean Q1 value for the actions the Actor *currently* proposes for the batch states.
            actor_proposed_actions = self.actor(state)
            q1_for_actor_loss = self.critic.Q1(state, actor_proposed_actions) # Use the Q1() helper method
            actor_loss = -q1_for_actor_loss.mean() # Maximize Q1 -> Minimize -Q1
            actor_loss_val = actor_loss.item() # Store scalar value
            logging.log(logging.DEBUG - 1, f"Iter {self.total_it}: Actor Loss={actor_loss_val:.4f} (Policy Update)")


            # 10. Optimize the Actor network
            self.actor_optimizer.zero_grad() # Reset gradients
            actor_loss.backward() # Compute gradients
            # Optional: Clip actor gradients if needed
            # nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=1.0)
            self.actor_optimizer.step() # Update actor weights

            # 11. Soft update the target networks (Actor and Critic)
            self.soft_update_target_networks()

        # Return losses for monitoring purposes
        return critic_loss_val, actor_loss_val


    def soft_update_target_networks(self):
        """
        Performs a soft update of the target network parameters.
        target_weights = tau * local_weights + (1 - tau) * target_weights
        """
        logging.log(logging.DEBUG - 2, "Performing soft target network update...") # Use lower debug level
        # Update Critic target network
        for target_param, local_param in zip(self.critic_target.parameters(), self.critic.parameters()):
            target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)

        # Update Actor target network
        for target_param, local_param in zip(self.actor_target.parameters(), self.actor.parameters()):
            target_param.data.copy_(self.tau * local_param.data + (1.0 - self.tau) * target_param.data)


    def save(self, actor_path: str, critic_path: str):
        """Saves the state dictionaries of the Actor, Critic, and their optimizers."""
        logging.info(f"Saving TD3 agent models -> Actor: {actor_path}, Critic: {critic_path}")
        try:
             # Save Actor state, optimizer state, and training iteration count
             torch.save({
                 'actor_state_dict': self.actor.state_dict(),
                 'actor_optimizer_state_dict': self.actor_optimizer.state_dict(),
                 'total_it': self.total_it # Save training progress
                 }, actor_path)
             # Save Critic state and optimizer state
             torch.save({
                 'critic_state_dict': self.critic.state_dict(),
                 'critic_optimizer_state_dict': self.critic_optimizer.state_dict()
                 }, critic_path)
        except Exception as e:
             logging.error(f"Error saving agent state: {e}", exc_info=True)

    def load(self, actor_path: str, critic_path: str) -> bool:
        """
        Loads the agent's state (networks and optimizers) from specified file paths.
        Also re-synchronizes the target networks.

        Args:
            actor_path (str): Path to the saved Actor checkpoint file.
            critic_path (str): Path to the saved Critic checkpoint file.

        Returns:
            bool: True if both Actor and Critic loaded successfully, False otherwise.
        """
        loaded_actor = False
        loaded_critic = False

        # --- Load Actor ---
        if not os.path.exists(actor_path):
            logging.error(f"Actor checkpoint file not found: {actor_path}")
        else:
            try:
                logging.info(f"Loading Actor state from: {actor_path}")
                checkpoint = torch.load(actor_path, map_location=self.device) # Load to specified device

                self.actor.load_state_dict(checkpoint['actor_state_dict'])
                self.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
                self.total_it = checkpoint.get('total_it', 0) # Load training iterations

                # Important: Re-initialize target network from the loaded actor weights
                self.actor_target = copy.deepcopy(self.actor)
                self.actor.to(self.device) # Ensure loaded model is on device
                self.actor_target.to(self.device)

                # Move optimizer states to the correct device (important if loading CPU model to GPU or vice-versa)
                for state in self.actor_optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor): state[k] = v.to(self.device)
                loaded_actor = True
            except Exception as e:
                logging.error(f"Error loading Actor state from {actor_path}: {e}", exc_info=True)

        # --- Load Critic ---
        if not os.path.exists(critic_path):
            logging.error(f"Critic checkpoint file not found: {critic_path}")
        else:
            try:
                logging.info(f"Loading Critic state from: {critic_path}")
                checkpoint = torch.load(critic_path, map_location=self.device)

                self.critic.load_state_dict(checkpoint['critic_state_dict'])
                self.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])

                # Important: Re-initialize target network from the loaded critic weights
                self.critic_target = copy.deepcopy(self.critic)
                self.critic.to(self.device) # Ensure loaded model is on device
                self.critic_target.to(self.device)

                # Move optimizer states to the correct device
                for state in self.critic_optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor): state[k] = v.to(self.device)
                loaded_critic = True
            except Exception as e:
                logging.error(f"Error loading Critic state from {critic_path}: {e}", exc_info=True)

        # --- Final Checks ---
        if loaded_actor and loaded_critic:
            logging.info("TD3 Agent state loaded successfully.")
            # Set networks to appropriate modes after loading
            self.actor.train()
            self.critic.train()
            self.actor_target.eval()
            self.critic_target.eval()
            return True
        else:
            logging.error("TD3 Agent loading failed (Actor or Critic or both). Check logs.")
            return False

### 4.5 Initialize TD3 Agent
*Create the agent instance if the environment is ready.*

In [None]:
agent = None
if 'env' in locals() and env is not None and env.state_size > 0:
    try:
        agent = TD3Agent(state_dim=env.state_size, action_dim=env.action_size, max_action=config.max_action_value, config=config, device=device)
        print("\n✅ TD3Agent initialized.")
    except Exception as e:
        print(f"❌ Error initializing TD3Agent: {e}")
        logging.error("Agent init failed.", exc_info=True)
else:
    print("❌ Cannot initialize agent: Environment not ready.")

## 5. Agent Training
 ---
*Live training dashboard and interactive controls.* Monitor and manage the training process.

### 5.1 Performance Metrics Tracker
**Tracks** key performance indicators during training and evaluation, such as rewards, losses, profits, prices, and load factors. Provides methods for logging data and plotting progress. Designed for use with the TD3 agent (tracks separate Actor/Critic losses).
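
As a rough illustration of the per-episode aggregation, the sketch below reproduces the arithmetic in plain Python: losses are averaged while skipping steps where no loss was computed, and the load factor is seats sold divided by `n_flights * seats_capacity * days`. The helper names (`nan_mean`, `summarize_episode`) and the sample numbers are hypothetical; the dictionary keys follow the per-step summaries the tracker stores.

```python
import math

def nan_mean(values):
    """Mean that skips NaN entries; returns NaN if nothing is left."""
    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else math.nan

def summarize_episode(steps, n_flights, seats_capacity):
    """Aggregate per-step records into episode-level metrics."""
    total_profit = sum(s["reward"] for s in steps)
    seats_sold = sum(s["total_seats_sold_day"] for s in steps)
    # Seats offered over the episode: flights * seats per flight * simulated days
    capacity = n_flights * seats_capacity * len(steps)
    return {
        "total_profit": total_profit,
        "avg_critic_loss": nan_mean([s["critic_loss"] for s in steps]),
        "avg_actor_loss": nan_mean([s["actor_loss"] for s in steps]),
        "load_factor": seats_sold / capacity if capacity > 0 else 0.0,
    }

# Two made-up days; the actor loss is NaN on the first day because the
# delayed policy update did not run on that step
steps = [
    {"reward": 1000.0, "critic_loss": 0.5, "actor_loss": math.nan, "total_seats_sold_day": 150},
    {"reward": 1500.0, "critic_loss": 0.3, "actor_loss": -2.0, "total_seats_sold_day": 210},
]
summary = summarize_episode(steps, n_flights=2, seats_capacity=180)
print(summary)
```

The class below performs the same aggregation with a pandas DataFrame, whose `mean()` skips NaN values by default.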

In [None]:
from typing import Dict, List, Optional, Tuple  # List is used in the annotations below

class PerformanceMetrics:
    """
    Tracks and visualizes performance metrics for the RL agent training process.

    Stores per-step and per-episode data, calculates summary statistics,
    and generates plots to monitor training progress and evaluation results.
    """
    def __init__(self, config: Config, n_flights: int):
        """
        Initializes the PerformanceMetrics tracker.

        Args:
            config: The configuration object containing parameters like price range,
                    capacity, validation frequency.
            n_flights: The number of unique flights managed by the agent/environment.
        """
        self.config = config
        self.n_flights = n_flights
        # Store relevant config parameters for calculations and plotting
        self.price_range: Tuple[float, float] = config.price_range
        self.seats_capacity: int = config.seats_capacity
        logging.info("Initializing PerformanceMetrics tracker.")

        # --- Data Storage ---
        # Use lists to store per-episode summary metrics dynamically
        self.episode_rewards: List[float] = []         # Total reward (profit) per episode
        self.episode_losses_critic: List[float] = []   # Average critic loss per episode
        self.episode_losses_actor: List[float] = []    # Average actor loss per episode (when updated)
        self.episode_profits: List[float] = []         # Same as episode_rewards if reward = profit
        self.episode_avg_prices: List[float] = []      # Average price set across all flights/steps in episode
        self.episode_load_factors: List[float] = []    # Average load factor (% seats sold) per episode
        self.episode_steps: List[int] = []             # Number of simulation steps taken in each episode

        # Store raw data for each step (can become large, consider sampling or limiting if memory is an issue)
        self.all_step_data: List[Dict] = []

        # Store results from periodic evaluation runs
        self.eval_rewards: List[float] = []

    def log_step(self, episode: int, step: int, reward: float,
                 critic_loss: Optional[float], actor_loss: Optional[float], info: Dict):
        """
        Logs data for a single environment step within an episode.

        Args:
            episode (int): The current episode index (0-based).
            step (int): The current step index within the episode (0-based).
            reward (float): The reward received for this step (daily total profit).
            critic_loss (Optional[float]): The critic loss calculated in this step (if any).
            actor_loss (Optional[float]): The actor loss calculated in this step (if any, due to policy delay).
            info (Dict): Additional information dictionary returned by env.step().
        """
        # Create a dictionary summarizing the key data for this step
        step_summary = {
            'episode': episode,
            'step': step,
            'reward': reward,
            # Store losses, using NaN if None (makes averaging easier)
            'critic_loss': critic_loss if critic_loss is not None else np.nan,
            'actor_loss': actor_loss if actor_loss is not None else np.nan,
            # Calculate average price set across all flights for this day
            'avg_price_day': np.mean(list(info.get('prices', {}).values())) if info.get('prices') else np.nan,
            # Calculate total seats sold across all flights for this day
            'total_seats_sold_day': sum(info.get('seats_sold', {}).values())
        }
        # Append the summary to the list of all step data
        self.all_step_data.append(step_summary)

    def log_episode(self, episode_idx: int, total_steps: int):
        """
        Calculates and logs summary metrics at the end of an episode using the stored step data.

        Args:
            episode_idx (int): The index of the just-completed episode (0-based).
            total_steps (int): The total number of steps taken in this episode.
        """
        # Filter the raw step data to get only the data for the specified episode
        # This is more robust than assuming the last N items belong to the episode
        episode_data = [d for d in self.all_step_data if d['episode'] == episode_idx]

        # If no step data was recorded for this episode (e.g., if logging failed), log a warning and return
        if not episode_data:
            logging.warning(f"No step data found for episode {episode_idx + 1}. Cannot log episode metrics.")
            # Append NaN/placeholders to keep list lengths consistent for plotting if needed
            self.episode_rewards.append(np.nan); self.episode_losses_critic.append(np.nan); self.episode_losses_actor.append(np.nan)
            self.episode_profits.append(np.nan); self.episode_avg_prices.append(np.nan); self.episode_load_factors.append(np.nan)
            self.episode_steps.append(total_steps) # Still log the number of steps taken
            return

        # Convert the list of step dictionaries into a pandas DataFrame for easier aggregation
        ep_df = pd.DataFrame(episode_data)

        # Calculate episode summary statistics
        total_reward = ep_df['reward'].sum()
        # Calculate average losses, ignoring NaN values (where loss wasn't computed)
        avg_critic_loss = ep_df['critic_loss'].mean()
        avg_actor_loss = ep_df['actor_loss'].mean() # Will be NaN if actor never updated
        total_profit = total_reward # Assuming reward is defined as profit
        avg_daily_price = ep_df['avg_price_day'].mean() # Average of daily average prices
        total_seats_sold_episode = ep_df['total_seats_sold_day'].sum() # Sum of seats sold each day

        # Calculate total seat capacity offered during the episode
        # Capacity = num_flights * seats_per_flight * num_days_in_episode
        # Use len(ep_df) which is the actual number of steps/days simulated
        total_capacity_episode = self.n_flights * self.seats_capacity * len(ep_df)
        # Calculate load factor (proportion of offered seats that were sold)
        load_factor = total_seats_sold_episode / total_capacity_episode if total_capacity_episode > 0 else 0.0

        # --- Store Episode Summaries ---
        # Append calculated metrics to their respective lists
        self.episode_rewards.append(total_reward)
        # Store 0 if average loss is NaN (e.g., no training steps occurred)
        self.episode_losses_critic.append(avg_critic_loss if pd.notna(avg_critic_loss) else 0.0)
        self.episode_losses_actor.append(avg_actor_loss if pd.notna(avg_actor_loss) else 0.0)
        self.episode_profits.append(total_profit)
        self.episode_avg_prices.append(avg_daily_price if pd.notna(avg_daily_price) else 0.0)
        self.episode_load_factors.append(load_factor)
        self.episode_steps.append(total_steps)

        # Log a summary message to the console for immediate feedback
        log_msg = (f"Ep {episode_idx + 1}: Steps={total_steps}, Profit={total_profit:,.0f}, "
                   f"AvgCritLoss={avg_critic_loss:.4f}, AvgActLoss={avg_actor_loss:.4f}, "
                   f"AvgPrice={avg_daily_price:.0f}, LoadFactor={load_factor:.1%}")
        logging.info(log_msg)

        # Optional: Prune older step data to manage memory usage if `all_step_data` grows too large
        # E.g., keep only data for the last N episodes or implement a max size for `all_step_data`

    def add_eval_result(self, reward: float):
         """
         Stores the average reward obtained from an evaluation run.

         Args:
             reward (float): The average reward over the evaluation episodes.
         """
         self.eval_rewards.append(reward)
         logging.debug(f"Added evaluation result: {reward:.2f}")

    def plot_training_progress(self, fig: plt.Figure, axs: np.ndarray, window: int = 10):
        """
        Updates the provided Matplotlib figure and axes with the latest training metrics.
        Designed to be called periodically (e.g., end of episode) for live dashboard updates.

        Args:
            fig (plt.Figure): The Matplotlib figure object for the dashboard.
            axs (np.ndarray): NumPy array of Matplotlib axes objects (e.g., from plt.subplots).
            window (int): The window size for calculating rolling means (smoothing).
        """
        logging.info("Updating training progress dashboard...")
        if not self.episode_rewards:
            print("No episode data recorded yet to plot.")
            return

        # --- Prepare Data for Plotting ---
        num_episodes = len(self.episode_rewards)
        episodes_axis = np.arange(1, num_episodes + 1) # X-axis for plots (Episode number)

        # Helper function for safe rolling mean calculation (handles NaN/Inf)
        def rolling_mean(data, w):
            series = pd.Series(data).replace([np.inf, -np.inf], np.nan).dropna()
            # Give the empty fallback an explicit dtype to avoid pandas' empty-Series dtype warning
            return series.rolling(w, min_periods=1).mean() if not series.empty else pd.Series(dtype=float)

        # Calculate smoothed versions of the metrics
        profits_smooth = rolling_mean(self.episode_profits, window)
        # Critic loss handling (get series, drop invalid, then smooth)
        critic_loss_series = pd.Series(self.episode_losses_critic).replace([np.inf,-np.inf],np.nan).dropna()
        critic_losses_smooth = critic_loss_series.rolling(window, min_periods=1).mean()
        # Actor loss handling
        actor_loss_series = pd.Series(self.episode_losses_actor).replace([np.inf,-np.inf],np.nan).dropna()
        actor_losses_smooth = actor_loss_series.rolling(window, min_periods=1).mean()
        # Other metrics
        prices_smooth = rolling_mean(self.episode_avg_prices, window)
        load_factor_smooth = rolling_mean(self.episode_load_factors, window)
        steps_smooth = rolling_mean(self.episode_steps, window)

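        # --- Clear and Update Plots ---
        # Remove twin axes left over from a previous call; otherwise every update
        # stacks another secondary y-axis on top of the losses plot
        primary_axes = list(axs.flat)
        for extra_ax in list(fig.axes):
            if extra_ax not in primary_axes:
                extra_ax.remove()
        # Clear previous plots from all axes to prepare for redraw
        for ax in primary_axes:
            ax.clear()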

        # --- Plot 1: Total Profit ---
        ax = axs[0, 0]
        ax.plot(episodes_axis, self.episode_profits, color='lightblue', alpha=0.6, label='Raw Profit')
        # Plot smoothed line using its own index + 1 to align with episode number
        ax.plot(profits_smooth.index + 1, profits_smooth, color='blue', label=f'Smoothed (w={window})')
        ax.set_ylabel('Total Profit per Episode')
        ax.set_title('Episode Profit')
        ax.legend(fontsize='small'); ax.grid(True)

        # --- Plot 2: Critic and Actor Losses ---
        ax = axs[0, 1]
        # Plot raw critic loss points (use its own index + 1)
        ax.plot(critic_loss_series.index + 1, critic_loss_series, color='lightcoral', alpha=0.4, linestyle='', marker='.', markersize=2, label='Raw Critic Loss')
        # Plot smoothed critic loss
        ax.plot(critic_losses_smooth.index + 1, critic_losses_smooth, color='red', label=f'Smooth Critic (w={window})')
        ax.set_ylabel('Critic Loss (Log Scale)', color='red')
        ax.set_yscale('log') # Log scale is often useful for losses
        ax.tick_params(axis='y', labelcolor='red')
        ax.grid(True, which='both', axis='y', linestyle=':', linewidth=0.5)
        ax.legend(loc='upper left', fontsize='small')
        ax.set_title('Training Losses')

        # Create a secondary y-axis for Actor loss
        ax2 = ax.twinx()
        ax2.plot(actor_loss_series.index + 1, actor_loss_series, color='skyblue', alpha=0.4, linestyle='', marker='.', markersize=2, label='Raw Actor Loss')
        ax2.plot(actor_losses_smooth.index + 1, actor_losses_smooth, color='deepskyblue', label=f'Smooth Actor (w={window})')
        ax2.set_ylabel('Actor Loss', color='deepskyblue') # Note: Actor loss is -Q1, so it is negative whenever Q1 is positive
        ax2.tick_params(axis='y', labelcolor='deepskyblue')
        ax2.legend(loc='upper right', fontsize='small')
        # ax2.grid(True, which='both', axis='y', linestyle=':', linewidth=0.5) # Optional secondary grid

        # --- Plot 3: Average Price ---
        ax = axs[1, 0]
        ax.plot(episodes_axis, self.episode_avg_prices, color='lightgreen', alpha=0.6, label='Raw Avg Price')
        ax.plot(prices_smooth.index + 1, prices_smooth, color='green', label=f'Smoothed (w={window})')
        # Add lines indicating the min/max price range from config
        ax.axhline(self.price_range[0], color='grey', linestyle='--', alpha=0.7, label=f'Min Price ({self.price_range[0]:.0f})')
        ax.axhline(self.price_range[1], color='grey', linestyle='--', alpha=0.7, label=f'Max Price ({self.price_range[1]:.0f})')
        ax.set_ylabel('Average Price (Fare)')
        ax.set_title('Average Daily Price (Fare)')
        ax.legend(fontsize='small'); ax.grid(True)

        # --- Plot 4: Load Factor ---
        ax = axs[1, 1]
        ax.plot(episodes_axis, self.episode_load_factors, color='thistle', alpha=0.6, label='Raw Load Factor')
        ax.plot(load_factor_smooth.index + 1, load_factor_smooth, color='purple', label=f'Smoothed (w={window})')
        ax.set_ylabel('Load Factor')
        ax.set_title('Average Load Factor')
        ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f'{y:.1%}')) # Format y-axis as percentage
        ax.set_ylim(bottom=-0.05, top=1.05) # Set bounds slightly outside 0-1
        ax.legend(fontsize='small'); ax.grid(True)

        # --- Plot 5: Episode Steps ---
        ax = axs[2, 0]
        ax.plot(episodes_axis, self.episode_steps, color='navajowhite', alpha=0.6, label='Raw Steps')
        ax.plot(steps_smooth.index + 1, steps_smooth, color='orange', label=f'Smoothed (w={window})')
        ax.set_xlabel('Episode')
        ax.set_ylabel('Number of Steps')
        ax.set_title('Episode Length')
        ax.legend(fontsize='small'); ax.grid(True)

        # --- Plot 6: Evaluation Reward ---
        ax = axs[2, 1]
        if self.eval_rewards:
            # X-axis represents the episode number when evaluation was performed
            eval_episode_numbers = np.arange(1, len(self.eval_rewards) + 1) * self.config.validation_freq
            ax.plot(eval_episode_numbers, self.eval_rewards, marker='o', linestyle='-', color='teal', label='Avg Eval Reward')
            # ax.set_xticks(eval_episode_numbers) # Optional: set ticks exactly at eval points
        ax.set_xlabel('Episode Number')
        ax.set_ylabel('Avg Reward (Profit)')
        ax.set_title('Validation Performance')
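        # Only draw a legend when there is a labeled line; avoids a Matplotlib warning
        if self.eval_rewards:
            ax.legend(fontsize='small')
        ax.grid(True)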

        # --- Final Touches ---
        # Update the main figure title
        fig.suptitle(f'TD3 Training Progress (Episode {num_episodes})', fontsize=14)
        # Adjust layout to prevent labels/titles overlapping
        fig.tight_layout(rect=[0, 0.03, 1, 0.96]) # Leave space for suptitle

        # Redraw the canvas to show the updated plots
        fig.canvas.draw()
        fig.canvas.flush_events() # Ensure updates are processed in notebook environment
        logging.debug("Training dashboard updated.")

    def get_final_metrics(self) -> Dict:
        """
        Calculates and returns a dictionary of final summary performance metrics
        after training has finished.

        Returns:
            Dict: A dictionary containing key summary metrics.
        """
        # Return status if no training data exists
        if not self.episode_rewards:
             return {"status": "No training data available"}

        # Helper function for safe mean calculation (handles NaN/Inf)
        def safe_mean(data: List[float]) -> float:
             series = pd.Series(data).replace([np.inf, -np.inf], np.nan).dropna()
             return series.mean() if not series.empty else 0.0

        # Calculate final metrics using safe mean
        avg_profit_last_10 = safe_mean(self.episode_profits[-10:])
        # Calculate max profit safely
        max_profit = pd.Series(self.episode_profits).replace([np.inf, -np.inf], np.nan).max()
        avg_load_factor_last_10 = safe_mean(self.episode_load_factors[-10:])
        avg_eval_reward = safe_mean(self.eval_rewards)

        # Compile metrics into a dictionary
        metrics_summary = {
            'total_episodes_run': len(self.episode_rewards),
            'avg_profit_last_10_eps': avg_profit_last_10,
            'max_profit_episode': max_profit if pd.notna(max_profit) else 0.0,
            'avg_load_factor_last_10_eps': avg_load_factor_last_10,
            'avg_evaluation_reward': avg_eval_reward
            # Add other relevant final stats if needed (e.g., final losses)
        }
        return metrics_summary

# --- Initialize the metrics tracker (requires an initialized environment) ---
if 'env' in locals() and env is not None:
    try:
        metrics_tracker = PerformanceMetrics(
            config=config,
            n_flights=env.n_flights
            # price_range is taken directly from config inside the class
        )
        print("✅ PerformanceMetrics tracker initialized.")
    except Exception as e:
        # Bind the exception so it can be reported in the message below
        print(f"❌ Failed to initialize PerformanceMetrics tracker: {e}")

### 5.2 Live Training Dashboard Setup
*Creates the plot area for live updates during training.*

In [None]:
# Create the figure and axes for the dashboard *once*
import matplotlib.pyplot as plt
from IPython.display import display

# Initialize the dashboard plot structure
def create_dashboard_figure(figsize=(12, 8)):
    plt.ioff() # Turn off interactive mode initially to prevent double plots
    fig, axs = plt.subplots(3, 2, figsize=figsize)
    fig.canvas.header_visible = False # Hide the canvas header (widget backend) for a cleaner look
    fig.suptitle('Training Dashboard (Pending Start...)', fontsize=14)
    plt.ion() # Turn interactive mode back on for updates
    return fig, axs

dashboard_fig, dashboard_axs = create_dashboard_figure()
display(dashboard_fig.canvas) # Display the canvas area in the output

### 5.3 Evaluation Function (`evaluate_model`)
**Runs** the trained agent in the environment for a specified number of episodes **without exploration noise** (greedy execution), so the reported reward reflects the learned policy alone.
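
The greedy-evaluation pattern can be sketched without the notebook's classes. In the sketch below, `ToyEnv` and `toy_policy` are hypothetical stand-ins (not `AirlinePricingEnv` or `TD3Agent`), and the reward function is made up; the sketch only illustrates running each episode with zero exploration noise and averaging the episode returns.

```python
import random
from statistics import mean


class ToyEnv:
    """Hypothetical 1-D environment: the step reward peaks when the action is 0.5."""

    def __init__(self, horizon=5):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return 0.0  # dummy initial state

    def step(self, action):
        self.t += 1
        reward = 1.0 - abs(action - 0.5)
        done = self.t >= self.horizon
        return float(self.t), reward, done, {}


def toy_policy(state, exploration_noise=0.0):
    """Hypothetical policy: always proposes 0.5, plus optional Gaussian noise, clipped to [-1, 1]."""
    noise = random.gauss(0.0, exploration_noise) if exploration_noise > 0 else 0.0
    return max(-1.0, min(1.0, 0.5 + noise))


def evaluate(env, policy, num_episodes=3):
    """Run the policy greedily (no noise) and return the mean episode return."""
    returns = []
    for _ in range(num_episodes):
        state, done, total = env.reset(), False, 0.0
        while not done:
            action = policy(state, exploration_noise=0.0)  # greedy: no exploration
            state, reward, done, _ = env.step(action)
            total += reward
        returns.append(total)
    return mean(returns)


avg_return = evaluate(ToyEnv(), toy_policy)
print(avg_return)
```

Because the noise is switched off, repeated evaluations of a deterministic policy in a deterministic environment give identical returns; averaging over several episodes only matters when the environment itself is stochastic.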

In [None]:
def evaluate_model(env: AirlinePricingEnv,
                   agent: TD3Agent,
                   config: Config,
                   num_eval_episodes: int = 5) -> float:
    """
    Evaluates the agent's deterministic policy over a specified number of episodes.

    Args:
        env (AirlinePricingEnv): The instantiated environment to evaluate in.
        agent (TD3Agent): The trained TD3 agent.
        config (Config): The configuration object (used for simulation length).
        num_eval_episodes (int): The number of episodes to run for evaluation.

    Returns:
        float: The average total reward (profit) achieved across all evaluation episodes.
               Returns -infinity if evaluation cannot proceed due to missing components.
    """
    # --- Pre-Evaluation Checks ---
    if not all([env, agent, config]):
        logging.error("Evaluation cannot start. Environment, Agent, or Config is missing.")
        return -np.inf # Return a value indicating failure

    logging.info(f"Starting evaluation phase for {num_eval_episodes} episode(s)...")

    # --- Set Agent to Evaluation Mode ---
    # This is crucial:
    # 1. Disables dropout layers (if any).
    # 2. Changes behavior of batch normalization layers (uses running stats instead of batch stats).
    # 3. Ensures deterministic behavior for evaluation.
    agent.actor.eval()

    # --- Evaluation Loop ---
    total_rewards = [] # List to store the total reward for each evaluation episode

    for episode in range(num_eval_episodes):
        # --- Episode Initialization ---
        # Reset the environment to get the starting state (as NumPy)
        state_np = env.reset().cpu().numpy()
        episode_reward = 0.0 # Accumulator for the current episode's reward
        done = False           # Flag to track episode termination

        # Initialize progress bar for the evaluation episode
        pbar = tqdm(total=config.simulation_length_days,
                    desc=f"Eval Episode {episode + 1}/{num_eval_episodes}",
                    leave=False) # leave=False removes the bar once the episode finishes

        # Disable gradient calculations during evaluation (improves performance and saves memory)
        with torch.no_grad():
            # --- Step Loop within Evaluation Episode ---
            while not done:
                # 1. Select Action Deterministically (Greedy Policy)
                # Call agent's action selection method with exploration_noise=0.0
                action = agent.select_action(state_np, exploration_noise=0.0)

                # 2. Step the Environment using the chosen action
                # Env expects action in [-1, 1] range
                next_state_tensor, reward, done, info = env.step(action)
                # Convert next state to NumPy for the next iteration's input
                next_state_np = next_state_tensor.cpu().numpy()

                # --- Update State and Accumulate Reward ---
                state_np = next_state_np  # Transition to the next state
                episode_reward += reward  # Add step reward to episode total

                # --- Update Progress Bar ---
                pbar.update(1)
                # Display current accumulated profit for the episode
                pbar.set_postfix({'Profit': f'{episode_reward:,.0f}'})

                # --- Check for Episode Termination ---
                # The 'done' flag comes directly from the environment step
                if done:
                    pbar.close() # Close the progress bar for this episode
                    break        # Exit the inner step loop

        # --- End of Evaluation Episode ---
        # Store the total reward achieved in this episode
        total_rewards.append(episode_reward)
        logging.info(f"  Evaluation Episode {episode + 1} finished. Total Reward: {episode_reward:,.2f}")

    # --- Post-Evaluation ---
    # Set the Actor network back to training mode (enables dropout/batch norm training behavior if used)
    agent.actor.train()

    # Calculate the average reward across all evaluation episodes
    average_reward = np.mean(total_rewards) if total_rewards else 0.0 # Handle case of zero episodes
    logging.info(f"Evaluation phase finished. Average Reward over {num_eval_episodes} episodes: {average_reward:,.2f}")

    return average_reward

### 5.4 Training Function (`train_model`) 
**Contains** the core training loop: it iterates over episodes, steps the environment with the TD3 agent, stores experience in the replay buffer, triggers learning updates, runs periodic validation with early stopping, and refreshes the live dashboard.
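
The patience-based early stopping used during validation can be sketched in isolation. `EarlyStopper` below is a hypothetical helper, not part of this notebook, and the validation scores are made-up numbers; it assumes, as the training function does, that only a strictly higher score counts as an improvement.

```python
class EarlyStopper:
    """Hypothetical helper: tracks the best validation score and a patience counter."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("-inf")
        self.bad_checks = 0

    def update(self, score):
        """Return (improved, should_stop) for a new validation score."""
        if score > self.best:  # strict improvement resets the counter
            self.best = score
            self.bad_checks = 0
            return True, False
        self.bad_checks += 1
        return False, self.bad_checks >= self.patience


validation_scores = [10.0, 12.0, 11.5, 11.9, 12.0, 13.0]  # made-up values
stopper = EarlyStopper(patience=3)
stopped_at = None
for check_idx, score in enumerate(validation_scores, start=1):
    improved, should_stop = stopper.update(score)
    if should_stop:
        stopped_at = check_idx
        break

print(stopper.best, stopped_at)
```

Note that a tie with the best score counts against patience here, so training in this sketch stops before the later, higher score is ever seen. That is the usual trade-off of a small patience value.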

In [None]:
def train_model(env: AirlinePricingEnv,
                agent: TD3Agent,
                metrics: PerformanceMetrics,
                config: Config,
                live_fig: plt.Figure,           # Live dashboard figure object
                live_axs: np.ndarray,           # Live dashboard axes array
                writer: Optional["SummaryWriter"] = None): # TensorBoard writer
    """
    Main training loop for the TD3 agent in the Airline Pricing Environment.

    Args:
        env: The instantiated AirlinePricingEnv.
        agent: The instantiated TD3Agent.
        metrics: The instantiated PerformanceMetrics tracker.
        config: The configuration object with hyperparameters.
        live_fig: The Matplotlib figure for the live dashboard.
        live_axs: The Matplotlib axes array for the live dashboard.
        writer: Optional TensorBoard SummaryWriter for logging.
    """
    # --- Pre-Training Checks ---
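    # Ensure all necessary components are provided.
    # Compare against None explicitly: truth-testing a multi-element NumPy array (live_axs) raises ValueError.
    required = {'env': env, 'agent': agent, 'metrics': metrics,
                'config': config, 'live_fig': live_fig, 'live_axs': live_axs}
    missing = [name for name, obj in required.items() if obj is None]
    if missing:
        logging.error(f"Training cannot start. Missing components: {', '.join(missing)}")
        return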

    logging.info(f"Starting TD3 training for {config.n_episodes} episodes...")
    logging.info(f"  Hyperparameters: Batch={config.batch_size}, ActorLR={config.actor_lr:.1e}, CriticLR={config.critic_lr:.1e}, PolicyFreq={config.policy_freq}")
    logging.info(f"  Initial random exploration steps: {config.start_timesteps}")
    logging.info(f"  Validation every {config.validation_freq} episodes, Patience={config.patience}")

    # --- Initialization ---
    best_eval_reward = -np.inf     # Track the best average reward achieved during validation
    patience_counter = 0         # Counter for early stopping mechanism
    global_step_counter = 0      # Track total steps across all episodes for logging/buffer check

    # --- Main Training Loop (over episodes) ---
    for episode_idx in range(config.n_episodes):
        # --- Episode Initialization ---
        # Reset the environment to get the initial state (as NumPy array for buffer)
        state_np = env.reset().cpu().numpy()
        # Reset episode-specific trackers
        episode_reward = 0.0
        episode_steps = 0
        episode_critic_losses = [] # Store critic losses for averaging this episode
        episode_actor_losses = []  # Store actor losses for averaging this episode
        start_time = time.time()   # Track episode duration

        # Initialize progress bar for the current episode
        pbar = tqdm(total=config.simulation_length_days,
                    desc=f"Ep {episode_idx + 1}/{config.n_episodes}",
                    leave=False) # leave=False removes bar after completion

        # --- Inner Loop (over steps within an episode) ---
        while True:
            # Increment counters
            global_step_counter += 1
            episode_steps += 1

            # 1. Select Action based on exploration strategy
            if global_step_counter < config.start_timesteps:
                # Initial phase: Take purely random actions sampled uniformly from [-max_action, max_action]
                action = np.random.uniform(-config.max_action_value, config.max_action_value, size=env.action_size)
                logging.log(logging.DEBUG - 1, f"Step {global_step_counter}: Random action (startup)")
            else:
                # After initial phase: Use agent's policy + exploration noise
                action = agent.select_action(state_np, exploration_noise=config.exploration_noise)
                # Note: select_action handles noise addition and clipping internally

            # 2. Step the Environment
            # Environment accepts action in [-1, 1], performs internal scaling to price
            next_state_tensor, reward, done, info = env.step(action)
            # Convert next state tensor to NumPy for storing in the replay buffer
            next_state_np = next_state_tensor.cpu().numpy()

            # Environment interaction completed for this step

            # 3. Add Experience to Replay Buffer
            # Store the transition: (current_state, action_taken, reward_received, next_state, done_flag)
            # Note: 'done' flag here indicates terminal state of the episode
            agent.replay_buffer.add(state_np, action, reward, next_state_np, done)

            # 4. Agent Learning Step (Train Networks)
            critic_loss, actor_loss = None, None # Initialize losses for this step
            # Only start training updates after the initial random exploration phase
            if global_step_counter >= config.start_timesteps:
                # Call the agent's train method, which handles batch sampling and updates
                critic_loss, actor_loss = agent.train()
                # Store the losses if they were calculated (agent.train returns None if buffer too small)
                if critic_loss is not None:
                    episode_critic_losses.append(critic_loss)
                if actor_loss is not None: # Actor loss might be None due to delayed updates
                    episode_actor_losses.append(actor_loss)

            # --- Logging & Monitoring ---
            # Log step-level data to the PerformanceMetrics tracker
            metrics.log_step(episode_idx, episode_steps, reward, critic_loss, actor_loss, info)
            # Log step-level data to TensorBoard (optional)
            if writer:
                 writer.add_scalar('Reward/Step_Reward', reward, global_step_counter)
                 if critic_loss is not None: writer.add_scalar('Loss/Critic_Step', critic_loss, global_step_counter)
                 if actor_loss is not None: writer.add_scalar('Loss/Actor_Step', actor_loss, global_step_counter) # Log only when updated

            # --- State Transition and Progress Bar Update ---
            state_np = next_state_np # Move to the next state
            episode_reward += reward  # Accumulate reward for the episode
            pbar.update(1)            # Increment progress bar

            # Update postfix display on the progress bar
            pbar_postfix = {'Profit': f'{episode_reward:,.0f}'} # Show accumulated profit
            if episode_critic_losses: pbar_postfix['CritL'] = f'{np.mean(episode_critic_losses):.3f}' # Show running avg critic loss
            if episode_actor_losses: pbar_postfix['ActL'] = f'{np.mean(episode_actor_losses):.3f}'   # Show running avg actor loss
            pbar.set_postfix(pbar_postfix)

            # --- Episode Termination Check ---
            if done:
                pbar.close() # Close the progress bar for this episode
                break        # Exit the inner step loop

        # --- End of Episode Actions ---
        # Log summary metrics for the completed episode
        metrics.log_episode(episode_idx, episode_steps)
        episode_duration = time.time() - start_time
        logging.debug(f"Episode {episode_idx + 1} finished in {episode_duration:.1f} seconds.")

        # Log episode summary data to TensorBoard (optional)
        if writer:
            episode_num_tb = episode_idx + 1 # Use 1-based index for TensorBoard display
            # Check if metrics were successfully logged for this episode before accessing lists
            if episode_idx < len(metrics.episode_profits):
                 writer.add_scalar('Reward/Episode_Total_Profit', metrics.episode_profits[episode_idx], episode_num_tb)
                 writer.add_scalar('Loss/Critic_Episode_Avg', metrics.episode_losses_critic[episode_idx], episode_num_tb)
                 writer.add_scalar('Loss/Actor_Episode_Avg', metrics.episode_losses_actor[episode_idx], episode_num_tb)
                 writer.add_scalar('Metrics/Load_Factor', metrics.episode_load_factors[episode_idx], episode_num_tb)
                 writer.add_scalar('Timing/Episode_Duration_Sec', episode_duration, episode_num_tb)
                 # Add exploration noise parameter if needed
                 # writer.add_scalar('Params/Exploration_Noise', config.exploration_noise, episode_num_tb) # Assuming noise is constant for now

        # --- Live Plot Update ---
        # Call the metrics plotting function, passing the live figure and axes
        # This updates the dashboard displayed in the notebook output cell
        try:
            metrics.plot_training_progress(live_fig, live_axs, window=15) # Adjust smoothing window if desired
        except Exception as e:
            logging.error(f"Error updating live dashboard plot: {e}", exc_info=True)


        # --- Periodic Validation and Early Stopping ---
        current_episode_num = episode_idx + 1 # Use 1-based episode number for checks
        # Perform validation check at specified frequency or on the very last episode
        if current_episode_num % config.validation_freq == 0 or current_episode_num == config.n_episodes:
            logging.info(f"--- Running Validation after Episode {current_episode_num} ---")
            # Evaluate the agent's current deterministic policy
            # Use a fixed number of episodes (e.g., 3-5) for stable evaluation
            eval_avg_reward = evaluate_model(env, agent, config, num_eval_episodes=3)
            # Store the evaluation result
            metrics.add_eval_result(eval_avg_reward)
            logging.info(f"  Validation Average Reward: {eval_avg_reward:,.2f}")
            # Log validation reward to TensorBoard
            if writer: writer.add_scalar('Reward/Validation_Average', eval_avg_reward, current_episode_num)

            # --- Model Saving (Based on Validation Performance) ---
            # Check if the current validation reward is the best seen so far
            if eval_avg_reward > best_eval_reward:
                logging.info(f"  New best validation reward! {eval_avg_reward:,.2f} > {best_eval_reward:,.2f}. Saving model...")
                best_eval_reward = eval_avg_reward
                # Save both actor and critic models (TD3 requirement)
                agent.save(config.model_path_actor, config.model_path_critic)
                patience_counter = 0 # Reset patience since we found a better model
            else:
                # If performance did not improve, increment patience counter
                patience_counter += 1
                logging.info(f"  Validation reward did not improve ({eval_avg_reward:,.2f} <= {best_eval_reward:,.2f}). Patience: {patience_counter}/{config.patience}")

            # --- Early Stopping Check ---
            # If patience counter reaches the limit, stop training early
            if patience_counter >= config.patience:
                logging.warning(f"EARLY STOPPING triggered at episode {current_episode_num} after {config.patience} validations without improvement.")
                break # Exit the main training loop (for episodes)

    # --- Post-Training ---
    logging.info("--- Training Loop Finished ---")

    # Save the final state of the dashboard plot as a static image
    try:
        final_plot_path = os.path.join(config.output_dir, "td3_training_final_plot.png")
        live_fig.savefig(final_plot_path)
        logging.info(f"Final training plot saved to: {final_plot_path}")
    except Exception as e:
        logging.error(f"Failed to save final dashboard plot: {e}")

    # Display final summary metrics from the tracker
    final_summary = metrics.get_final_metrics()
    logging.info(f"Final Training Summary: {final_summary}")

### 5.5 Interactive Training Control
*Widgets to set key hyperparameters and initiate the training process.* This allows for easy experimentation without editing code directly.
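
A log-scaled slider takes its `min` and `max` as exponents and moves in equal steps of the exponent, while its value is `base ** exponent`. The plain-Python sketch below lists the learning rates such a slider would step through; `log_slider_values` is a hypothetical helper written for illustration, not an ipywidgets function.

```python
def log_slider_values(min_exp, max_exp, step, base=10.0):
    """Hypothetical helper: values visited by a log-scaled slider with the given exponent range."""
    n_steps = round((max_exp - min_exp) / step)
    return [base ** (min_exp + i * step) for i in range(n_steps + 1)]


# Same exponent range and step as the Actor LR slider in this section
actor_lr_values = log_slider_values(min_exp=-5, max_exp=-2, step=0.1)

print(len(actor_lr_values))
print(f"{actor_lr_values[0]:.1e} ... {actor_lr_values[-1]:.1e}")
```

Equal exponent steps give each decade the same number of slider positions, which is why a log scale tends to suit learning rates better than a linear slider with a fixed step.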

In [None]:
# --- Define Interactive Widgets ---

# Button to start the training process
button_start_train = Button(
    description="▶️ Start Training",
    button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click to begin training with the current settings',
    icon='play',              # FontAwesome icon name
    layout=Layout(width='auto', margin='10px 0 0 0') # Add some top margin
)

# Slider for Actor Learning Rate (Logarithmic steps often better for LR)
# Using FloatLogSlider requires ipywidgets >= 7.0.0
try:
    from ipywidgets import FloatLogSlider
    slider_actor_lr = FloatLogSlider(
        value=config.actor_lr,   # Default value from config
        base=10,                # Logarithmic base
        min=-5,                 # Exponent for min value (10^-5)
        max=-2,                 # Exponent for max value (10^-2)
        step=0.1,               # Step size for the exponent
        description='Actor LR:',
        readout_format='.1e',   # Display in scientific notation
        layout=Layout(width='95%')
    )
except ImportError:
     # Fallback to FloatSlider if FloatLogSlider is not available
     slider_actor_lr = FloatSlider(
        value=config.actor_lr,
        min=1e-5, max=1e-3, step=1e-5,
        description='Actor LR:', readout_format='.1e',
        layout=Layout(width='95%')
     )


# Slider for Critic Learning Rate
try:
    slider_critic_lr = FloatLogSlider(
        value=config.critic_lr, base=10, min=-4, max=-2, step=0.1,
        description='Critic LR:', readout_format='.1e',
        layout=Layout(width='95%')
    )
except (ImportError, NameError):  # NameError: FloatLogSlider was never imported above
     slider_critic_lr = FloatSlider(
        value=config.critic_lr, min=1e-4, max=5e-3, step=1e-4,
        description='Critic LR:', readout_format='.1e',
        layout=Layout(width='95%')
     )

# Slider for Number of Episodes
slider_episodes = IntSlider(
    value=config.n_episodes, # Default from config
    min=10,                 # Sensible minimum
    max=1000,               # Sensible maximum (adjust as needed)
    step=10,                # Step size
    description='Episodes:',
    layout=Layout(width='95%')
)

# --- Widget Event Handler ---

training_running = False # Global flag to prevent starting multiple training runs

def on_button_clicked(b: Button):
    """
    Callback function executed when the 'Start Training' button is clicked.
    Updates config, re-initializes optimizers if needed, and calls train_model.
    """
    global training_running # Use the global flag

    # Prevent starting if already running
    if training_running:
        print("⚠️ Training is already in progress or has finished for this session.")
        logging.warning("Attempted to start training while already running.")
        return

    print("🚀 Training requested...")
    # Set flag and update button appearance to indicate running state
    training_running = True
    button_start_train.disabled = True
    button_start_train.description = "⏳ Training Running..."
    button_start_train.button_style = 'info'
    button_start_train.icon = 'spinner'

    # --- Update Configuration from Widgets ---
    # Read the current values from the sliders and update the global config object
    config.actor_lr = slider_actor_lr.value
    config.critic_lr = slider_critic_lr.value
    config.n_episodes = slider_episodes.value
    print(f"  Updated Config: Episodes={config.n_episodes}, ActorLR={config.actor_lr:.1e}, CriticLR={config.critic_lr:.1e}")

    # --- Re-initialize Optimizers (Crucial!) ---
    # If the agent exists, create new optimizer instances with the updated learning rates
    # This ensures the training uses the rates selected via the sliders.
    if 'agent' in globals() and agent is not None:
         try:
            agent.actor_optimizer = optim.Adam(agent.actor.parameters(), lr=config.actor_lr)
            agent.critic_optimizer = optim.Adam(agent.critic.parameters(), lr=config.critic_lr)
            print(f"  Agent optimizers re-initialized with updated learning rates.")
         except Exception as e:
             logging.error(f"Failed to re-initialize optimizers: {e}")
             # Optionally handle error, e.g., by stopping or using old optimizers


    # --- Call the Main Training Function ---
    print(f"  Starting training loop for {config.n_episodes} episodes...")
    # Ensure all necessary components are available before starting
    if ('env' in globals() and env is not None and
        'agent' in globals() and agent is not None and
        'metrics_tracker' in globals() and metrics_tracker is not None and
        'dashboard_fig' in globals() and 'dashboard_axs' in globals()): # Check for plot objects
         try:
            # Pass all required arguments, including the live plot figure and axes
            train_model(env=env, agent=agent, metrics=metrics_tracker, config=config,
                        live_fig=dashboard_fig, live_axs=dashboard_axs, # Pass dashboard objects
                        writer=globals().get('writer')) # Pass TensorBoard writer (None if it was never created)
         except Exception as e:
            # Catch potential errors during the training loop execution
            logging.error(f"Exception occurred during train_model execution: {e}", exc_info=True)
            print(f"❌ Training encountered an error: {e}")
         finally:
            # --- Post-Training Cleanup (Button State) ---
            # Reset button state regardless of whether training finished successfully or errored
            button_start_train.disabled = False
            button_start_train.description = "▶️ Start Training"
            button_start_train.button_style = 'success'
            button_start_train.icon = 'play'
            # Decide whether to reset `training_running` flag to allow re-runs
            # For safety, keeping it True might prevent accidental re-runs in the same kernel session
            # training_running = False # Uncomment to allow re-running training
            print("🏁 Training function execution complete.")
    else:
        # If components are missing, report error and reset button
        print("❌ Cannot start training - one or more required components (Environment, Agent, Metrics Tracker, Dashboard Figure/Axes) are not ready.")
        button_start_train.disabled = False
        button_start_train.description = "▶️ Start Training"
        button_start_train.button_style = 'success'
        button_start_train.icon = 'play'
        training_running = False # Allow trying again if components become ready later


# --- Link Button Click Event to Handler ---
button_start_train.on_click(on_button_clicked)

# --- Display Widgets ---
# Use VBox/HBox for better layout control
controls_box = VBox([
    HTML("<h4>Training Controls:</h4>"),
    slider_episodes,
    slider_actor_lr,
    slider_critic_lr,
    button_start_train
], layout=Layout(border='1px solid #ccc', padding='10px', margin='10px 0'))

display(controls_box)

## 6. Evaluation & Results
 ---
*Interactive result exploration and policy visualization.* Analyze the performance of the trained agent.

### 6.1 Run Final Evaluation
 *Load the best performing model saved during training (if it exists) and run the `evaluate_model` function to get a final performance measure.* Logs results and optionally sends final metrics to TensorBoard.
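
TensorBoard's hyperparameter logging expects plain scalar values, so container-typed config fields need to be stringified first. A minimal sketch of that flattening step, assuming a dataclass-style config: `DemoConfig` and `flatten_hparams` are hypothetical names, and the field values (including the file path and price bounds) are made up for illustration.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple


@dataclass
class DemoConfig:
    """Hypothetical stand-in for a config object with mixed field types."""
    actor_lr: float = 1e-4
    n_episodes: int = 200
    price_range: Tuple[float, float] = (50.0, 500.0)  # made-up bounds
    data_paths: Dict[str, str] = field(
        default_factory=lambda: {"historical_data": "./data/example.csv"}  # made-up path
    )


def flatten_hparams(cfg) -> dict:
    """Keep plain scalars as they are and stringify everything else (dicts, lists, tuples, ...)."""
    flat = {}
    for key, value in vars(cfg).items():
        if isinstance(value, (bool, int, float, str)):
            flat[key] = value
        else:
            flat[key] = str(value)
    return flat


flat = flatten_hparams(DemoConfig())
for name, value in flat.items():
    print(f"{name}: {value!r}")
```

Checking for the allowed scalar types, rather than listing the container types to convert, means a field of any unexpected type is still turned into a string instead of causing a logging error.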


In [None]:
# --- Final Evaluation Trigger Cell ---
print("\n--- Final Model Evaluation ---")


# Default reward value if evaluation can't run
final_eval_reward = -np.inf

# Check if environment and agent are properly initialized
if 'env' in locals() and env is not None and 'agent' in locals() and agent is not None:

    # --- Load Best Saved Model ---
    # Attempt to load the actor and critic weights saved during training
    # These paths are defined in the config section
    model_loaded = agent.load(config.model_path_actor, config.model_path_critic)

    # Provide feedback on whether the model was loaded
    if model_loaded:
        print("✅ Successfully loaded best saved Actor and Critic models.")
    else:
        print("⚠️ Could not load saved models (or paths invalid). Evaluating the agent's current state.")

    # --- Run Evaluation ---
    # Call the evaluation function with the (potentially loaded) agent
    final_eval_reward = evaluate_model(env, agent, config, num_eval_episodes=5) # Use 5 episodes for final eval
    print(f"\n🏆 Final Average Evaluation Reward: {final_eval_reward:,.2f}")

    # --- Log Hyperparameters and Final Metrics to TensorBoard ---
    # Check if TensorBoard writer and metrics tracker were initialized successfully
    if 'writer' in locals() and writer is not None and 'metrics_tracker' in locals() and metrics_tracker is not None:
         try:
             # Prepare hyperparameters dictionary (stringify anything that is not a plain scalar, e.g. dicts, lists, tuples)
             hparam_dict_flat = {k: v if isinstance(v, (bool, int, float, str)) else str(v) for k, v in config.__dict__.items()}

             # Prepare final metrics dictionary
             # Start with the final evaluation reward
             metric_dict = {'hparam/final_eval_reward': final_eval_reward}
             # Get summary metrics from the tracker
             final_metrics_summary = metrics_tracker.get_final_metrics()
             # Safely add other summary metrics to the dictionary
             metric_dict['hparam/avg_profit_last10'] = final_metrics_summary.get('avg_profit_last_10_eps', 0)
             metric_dict['hparam/load_factor_last10'] = final_metrics_summary.get('avg_load_factor_last_10_eps', 0)
             metric_dict['hparam/max_profit_episode'] = final_metrics_summary.get('max_profit_episode', 0)

             # Get run name if defined, otherwise TensorBoard might use default naming
             run_name_tb = run_name if 'run_name' in locals() else None
             # Log the hyperparameters and final metrics
             writer.add_hparams(hparam_dict_flat, metric_dict, run_name=run_name_tb)
             print("Logged HParams and final metrics to TensorBoard.")
         except Exception as e:
              # Catch potential errors during logging
              logging.error(f"Failed to log HParams/Metrics to TensorBoard: {e}", exc_info=True)
    else:
        # Explain why HParams weren't logged if applicable
        if 'writer' not in locals() or writer is None: print("Skipping HParam logging: TensorBoard writer not initialized.")
        if 'metrics_tracker' not in locals() or metrics_tracker is None: print("Skipping HParam logging: Metrics tracker not initialized.")

else:
    # If env or agent weren't ready
    print("❌ Skipping final evaluation: Environment or Agent was not initialized successfully.")

### 6.2 Interactive Policy Visualization (Example)
 *Allows selecting a specific route and airline via dropdowns to visualize the pricing strategy learned by the agent over the simulation period.* This requires running a simulation loop using the agent's greedy policy.
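
Plotting prices requires mapping the agent's normalized action back to a fare. A worked sketch of that linear mapping follows, assuming the price delta is simply the upper bound minus the lower bound; `action_to_price` is a hypothetical helper and the price bounds are made-up values.

```python
def action_to_price(action, price_min, price_max):
    """Map an action in [-1, 1] linearly onto [price_min, price_max], clipping out-of-range input."""
    price_delta = price_max - price_min  # assumed definition of the delta
    price = price_min + (action + 1.0) * 0.5 * price_delta
    return max(price_min, min(price_max, price))


low, high = 50.0, 500.0  # made-up price bounds
prices = [action_to_price(a, low, high) for a in (-1.0, 0.0, 1.0, 1.5)]
print(prices)  # [50.0, 275.0, 500.0, 500.0]
```

An action of -1 lands on the lower bound, 0 on the midpoint, and +1 on the upper bound; anything outside that range is clipped to the nearest bound.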

In [None]:
# --- Interactive Policy Visualization Cell ---
print("\n--- Policy Visualization Tool ---")

if ('env' in locals() and env is not None and
    'agent' in locals() and agent is not None and
    'processed_data' in locals() and processed_data is not None and not processed_data.empty):

    # --- Prepare Widget Options ---
    # Get unique routes and airlines from the processed data for dropdowns
    try:
        route_options = sorted(processed_data['Route'].unique().tolist())
        airline_options = sorted(processed_data['Airline'].unique().tolist())
        # Check if options are valid
        if not route_options or not airline_options:
            raise ValueError("No routes or airlines found in processed data.")
    except Exception as e:
        print(f"❌ Error preparing widget options: {e}. Cannot create visualization tool.")
        route_options = ['N/A'] # Provide dummy options on error
        airline_options = ['N/A']

    # --- Define Interactive Function ---
    # Use the @interact decorator to automatically create widgets
    @interact
    def show_policy_plot(route=Dropdown(options=route_options, description="Route:"),
                         airline=Dropdown(options=airline_options, description="Airline:")):
        """
        This function is called whenever a dropdown selection changes.
        It simulates the agent's greedy policy for the selected flight
        and plots the resulting price trajectory.
        """
        # Avoid running if dummy options are selected
        if route == 'N/A' or airline == 'N/A':
            return

        print(f"🔍 Simulating greedy policy for: {route} - {airline}...")
        # Double-check agent and env availability inside the function
        if not agent or not env:
            print("❌ Agent or Environment not ready.")
            return

        # Find the specific index for this flight combination in the environment's mapping
        try:
            flight_idx = env.flight_map[(route, airline)]
        except KeyError:
            print(f"❌ Error: Combination '{route}' - '{airline}' not found in environment's flight map.")
            return

        # --- Simulate Greedy Policy ---
        agent.actor.eval() # Set agent to evaluation mode
        state_np = env.reset().cpu().numpy() # Reset environment for simulation
        prices_over_time = [] # List to store prices for the selected flight
        days = []             # List to store corresponding dates
        current_sim_date = env.start_date # Start date for x-axis

        # Run simulation loop (similar to evaluation, but store specific price)
        with torch.no_grad():
            for day_step in range(config.simulation_length_days):
                # Get the full action vector from the agent (greedy)
                action_vector = agent.select_action(state_np, exploration_noise=0.0)

                # Extract the specific action for the selected flight index
                action_selected_flight = action_vector[flight_idx]

                # --- Convert action [-1, 1] to actual price ---
                # Use the environment's scaling parameters
                price = env.price_range[0] + (action_selected_flight + 1.0) * 0.5 * env.price_delta
                price = np.clip(price, env.price_range[0], env.price_range[1]) # Clip final price

                # Store the price and date for plotting
                prices_over_time.append(price)
                days.append(current_sim_date)

                # Step the *whole* environment with the full action vector to get the next state
                next_state_tensor, _, done, _ = env.step(action_vector)
                state_np = next_state_tensor.cpu().numpy() # Update state for next iteration
                current_sim_date += pd.Timedelta(days=1) # Increment date

                # Stop if the environment signals done (end of simulation period)
                if done:
                    break

        agent.actor.train() # Set agent back to training mode

        # --- Plotting Results ---
        if prices_over_time:
            plt.figure(figsize=(12, 5)) # Create a new figure for the plot
            plt.plot(days, prices_over_time, marker='.', linestyle='-', markersize=4, label=f'{airline} Price')
            plt.title(f"Agent's Learned Pricing Policy: {route} ({airline})")
            plt.xlabel("Date")
            plt.ylabel("Price") # The dataset filename is not a unit; append the currency here if known
            # Set y-axis limits slightly wider than the price range for better visibility
            plt.ylim(bottom=env.price_range[0] * 0.95, top=env.price_range[1] * 1.05)
            plt.grid(True, linestyle='--', alpha=0.6)
            plt.xticks(rotation=30, ha='right') # Rotate date labels for readability
            plt.legend()
            plt.tight_layout() # Adjust layout
            plt.show() # Display the plot
        else:
            print("ℹ️ No price data was generated during the simulation for this flight.")

else:
    # Message if essential components are missing
    print("⚠️ Skipping Policy Visualization: Environment, Agent, or Processed Data not available.")
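
The action-to-price conversion used in the visualization cell is a linear rescaling from `[-1, 1]` to the price range, followed by a clip. A minimal sketch is shown below. It assumes that `env.price_delta` equals `price_range[1] - price_range[0]`, which the cell does not state explicitly; `action_to_price` and `example_range` are hypothetical names introduced for illustration.

```python
import numpy as np


def action_to_price(action, price_range):
    """Map actions in [-1, 1] linearly onto [low, high], then clip."""
    low, high = price_range
    price_delta = high - low  # assumption: matches env.price_delta
    price = low + (np.asarray(action, dtype=float) + 1.0) * 0.5 * price_delta
    return np.clip(price, low, high)


example_range = (50.0, 500.0)  # hypothetical price range, for illustration only
prices = action_to_price([-1.0, 0.0, 1.0, 1.4], example_range)
print(prices)  # [ 50. 275. 500. 500.]
```

An action of `-1` maps to the lower bound, `0` to the midpoint, and `1` to the upper bound. The out-of-range action `1.4` is clipped to the upper bound, which is why the cell clips after rescaling.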

## 7. Cleanup
 ---
 *Final steps, like closing the TensorBoard writer.*

### 7.1 Close TensorBoard Writer

In [None]:
if 'writer' in locals() and writer:
    try:
        writer.close()
        print("\nTensorBoard writer closed.")
    except Exception as e:
        print(f"Error closing TB writer: {e}")

print("\n--- Notebook Execution Finished ---")