In [None]:
import time
import torch
import yaml
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import datetime
from collections import deque
import random
import os
import argparse
import logging
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from dataclasses import dataclass, field
from typing import Tuple, Dict, Any  # Import necessary types

# Type hinting and docstrings in Config
@dataclass
class Config:
    """Configuration for the airline pricing environment and agent.

    Instances are built either from defaults or from a mapping loaded from a
    YAML file (``Config(**mapping)``).  Because YAML has no tuple type, values
    that must be tuples are normalised in ``__post_init__``.
    """
    seed: int = 42
    """Random seed for reproducibility."""
    price_scaler_range: Tuple[float, float] = (0.1, 0.9)
    """Range for the price scaler."""
    outlier_std_threshold: int = 3
    """Number of standard deviations for outlier removal."""
    max_days_ahead: int = 90
    """Maximum days ahead for booking."""
    simulation_length_days: int = 365
    """Length of the simulation in days."""
    seats_capacity: int = 150
    """Capacity of the aircraft."""
    action_days_ahead: int = 1
    """Number of days ahead the agent sets prices."""
    memory_size: int = 10000
    """Size of the agent's replay memory."""
    gamma: float = 0.99
    """Discount factor for future rewards."""
    epsilon: float = 1.0
    """Initial exploration rate."""
    epsilon_min: float = 0.01
    """Minimum exploration rate."""
    epsilon_decay: float = 0.998
    """Exploration rate decay factor."""
    learning_rate: float = 0.0005
    """Learning rate for the optimizer."""
    batch_size: int = 128
    """Batch size for training."""
    target_update_freq: int = 20
    """Frequency of target network updates."""
    patience: int = 20
    """Patience for early stopping."""
    n_episodes: int = 100
    """Number of training episodes."""
    data_generation_params: Dict[str, Any] = field(default_factory=lambda: {  # Use a dictionary
        'min_price': 2000.0,
        'max_price': 8000.0,
        'min_demand': 50,
        'max_demand': 150,
        'min_temp': 0.0,
        'max_temp': 35.0,
        'min_fuel_price': 2.0,
        'max_fuel_price': 4.0
    })
    """Parameters for data generation (if needed)."""

    def __post_init__(self) -> None:
        """Normalise values whose type may differ when loaded from YAML.

        Raises:
            ValueError: If ``price_scaler_range`` does not hold exactly two
                values.
        """
        # A YAML sequence arrives as a list; store a real 2-tuple of floats so
        # consumers that insist on a tuple (e.g. a scaler's feature range)
        # accept the value regardless of where the config came from.
        bounds = tuple(float(bound) for bound in self.price_scaler_range)
        if len(bounds) != 2:
            raise ValueError(
                f"price_scaler_range must contain exactly two values, got {len(bounds)}"
            )
        self.price_scaler_range = bounds



# Configure logging first.  Any logging call made on the root logger before
# basicConfig() installs a default handler, after which basicConfig() silently
# does nothing -- so the "config file not found" warning below must not run
# before this line.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Argument parsing for config file and overrides
parser = argparse.ArgumentParser(description="Train DQN agent for airline pricing.")
parser.add_argument('--config', type=str, default='config.yaml', help='Path to config file (YAML)')
parser.add_argument('--n_episodes', type=int, help='Override number of training episodes')  # Example override
# ... Add other argument overrides as needed ...

# parse_known_args() tolerates arguments this script does not define (for
# example flags injected by the host process when the code runs as a notebook
# cell) instead of aborting with SystemExit.
args, unknown_args = parser.parse_known_args()
if unknown_args:
    logging.warning("Ignoring unrecognized command-line arguments: %s", unknown_args)

# Load config from YAML or use defaults
try:
    with open(args.config, 'r', encoding='utf-8') as f:
        # safe_load() returns None for an empty document; treat that as "no overrides".
        config_dict = yaml.safe_load(f) or {}
    if not isinstance(config_dict, dict):
        raise TypeError(
            f"Config file {args.config} must contain a mapping at the top level, "
            f"got {type(config_dict).__name__}"
        )
    config = Config(**config_dict)
except FileNotFoundError:
    logging.warning(f"Config file not found: {args.config}. Using default configuration.")
    config = Config()

# Override config parameters from command-line arguments.
# Compare against None so that an explicit value of 0 is still honoured.
if args.n_episodes is not None:
    config.n_episodes = args.n_episodes


# Set seeds for reproducibility (using a single seed now)
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.backends.cudnn.deterministic = True  # Important for reproducibility with CUDA
torch.backends.cudnn.benchmark = False     # For fair comparison across runs

In [None]:
class DataProcessor:
    """Clean, merge and feature-engineer the raw airline data.

    The class validates the four input frames (historical bookings, fuel
    prices, climate, holidays), merges them into one feature table, derives a
    7-row rolling demand mean and a price-elasticity estimate per
    (Route, Airline), and owns the scalers and categorical encoders that the
    rest of the file reads from it.
    """

    # Columns every input frame must provide before it can be transformed.
    REQUIRED_COLUMNS = {
        'historical_data': ['Date', 'Route', 'Airline', 'AircraftType', 'Price', 'Demand', 'Capacity'],
        'fuel_prices': ['Date', 'FuelPrice'],
        'climate_data': ['Date', 'Location', 'Temperature', 'WeatherCondition'],
        'holiday_data': ['Date', 'Location', 'HolidayName']
    }

    def __init__(self):
        self.state_scaler = StandardScaler()
        # Config fields are lower-case; tuple() also covers a range that was
        # loaded from YAML as a list.
        self.price_scaler = MinMaxScaler(feature_range=tuple(config.price_scaler_range))
        self.demand_scaler = StandardScaler()
        # Read once here so _remove_outliers() does not touch the global config
        # at class-definition time.
        self.outlier_std_threshold = config.outlier_std_threshold
        self.route_encoder = {}
        self.airline_encoder = {}
        self.aircraft_encoder = {}
        self.weather_encoder = {}
        self.holiday_encoder = {}
        self.fuel_price_mean = None  # To store the mean fuel price.
        # Upper clip bound for scale_prices(); assigned by the environment once
        # it has seen the data.  None means "no upper bound".
        self.historical_price_max = None
        # Reference (dropped) category per encoder; populated by fit().
        self.route_reference = None
        self.airline_reference = None
        self.aircraft_reference = None
        self.weather_reference = None
        self.holiday_reference = None

    @staticmethod
    def _with_parsed_dates(frame):
        """Return a copy of ``frame`` whose 'Date' column is datetime64."""
        frame = frame.copy()
        # Plain column assignment replaces the column (and its dtype).  Writing
        # through ``.loc[:, 'Date']`` would try to keep the existing dtype, so
        # string dates would stay ``object`` and the ``.dt`` accessor would fail.
        frame['Date'] = pd.to_datetime(frame['Date'])
        return frame

    @staticmethod
    def _build_encoder(values):
        """Map each distinct value to an integer code in order of first appearance."""
        return {value: code for code, value in enumerate(values.unique())}

    def fit(self, historical_data, fuel_prices, climate_data, holiday_data):
        """Fit the scalers and build the categorical encoders.

        Args:
            historical_data: Booking history frame (must not be empty).
            fuel_prices: Fuel price frame; may be ``None`` or empty, in which
                case a constant placeholder price is used.
            climate_data: Climate frame (must not be empty).
            holiday_data: Holiday frame (must not be empty).

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: If a mandatory frame is empty, if ``fuel_prices`` lacks
                a 'FuelPrice' column, or if ``transform`` rejects the inputs.
        """
        if historical_data.empty or climate_data.empty or holiday_data.empty:  # fuel_prices can now be empty.
            raise ValueError("Input data cannot be empty (except fuel_prices, which is handled separately).")

        if fuel_prices is None or fuel_prices.empty:
            logging.warning("fuel_prices DataFrame is empty.  Creating a placeholder with constant fuel price.")
            self.fuel_price_mean = 2.5  # A reasonable default.  Could also be put in Config.
            # One placeholder row per distinct booking date.
            fuel_prices = pd.DataFrame({
                'Date': pd.to_datetime(historical_data['Date'].unique()),
                'FuelPrice': self.fuel_price_mean,
            })
        else:
            if 'FuelPrice' not in fuel_prices.columns:
                raise ValueError("Missing FuelPrice column")
            # Work on a copy so the caller's frame is never modified.
            fuel_prices = self._with_parsed_dates(fuel_prices)
            self.fuel_price_mean = fuel_prices['FuelPrice'].mean()  # Use the actual mean
            # Fill only the price column; filling the whole frame would also
            # push a float into missing dates.
            fuel_prices['FuelPrice'] = fuel_prices['FuelPrice'].fillna(self.fuel_price_mean)

        # The result is discarded: transform() copies its inputs, so this call
        # only serves to validate that the four frames can be processed.
        _ = self.transform(historical_data, fuel_prices, climate_data, holiday_data)

        # NOTE(review): the state scaler is fitted on a single hand-written
        # sample, so it only shifts states by these values (variance is zero).
        # Confirm whether it should be fitted on real states instead.
        sample_state = [0.0, 1.0, 0.0, 2.5, 20.0, 0.0, 150.0, 300.0, 100.0, -0.5]
        sample_state.extend([0.0] * 5)  # 5 categorical features (route, airline, aircraft, weather, season)
        self.state_scaler.fit(np.array([sample_state]))
        self.demand_scaler.fit(historical_data[['Demand']])

        # Create the encoders
        self.route_encoder = self._build_encoder(historical_data['Route'])
        self.airline_encoder = self._build_encoder(historical_data['Airline'])
        self.aircraft_encoder = self._build_encoder(historical_data['AircraftType'])
        self.weather_encoder = self._build_encoder(climate_data['WeatherCondition'])
        self.holiday_encoder = self._build_encoder(holiday_data['HolidayName'])

        # The first category seen in each column acts as the reference level.
        self.route_reference = next(iter(self.route_encoder))
        self.airline_reference = next(iter(self.airline_encoder))
        self.aircraft_reference = next(iter(self.aircraft_encoder))
        self.weather_reference = next(iter(self.weather_encoder))
        self.holiday_reference = next(iter(self.holiday_encoder))
        return self

    def transform(self, historical_data, fuel_prices, climate_data, holiday_data):
        """Build the merged, feature-engineered booking table.

        The inputs are copied, never modified.  The result is sorted by
        (Route, Airline, Date) and gains the columns DayOfWeek, Month,
        IsWeekend, Season, FuelPrice, Origin, Temperature, WeatherCondition,
        IsHoliday, Demand_MA7 and PriceElasticity.

        Raises:
            ValueError: If any frame lacks one of its required columns.
        """
        frames = {
            'historical_data': historical_data,
            'fuel_prices': fuel_prices,
            'climate_data': climate_data,
            'holiday_data': holiday_data,
        }
        for df_name, cols in self.REQUIRED_COLUMNS.items():
            missing_cols = [col for col in cols if col not in frames[df_name].columns]
            if missing_cols:
                raise ValueError(f"Missing columns in {df_name}: {', '.join(missing_cols)}")

        historical_data = self._with_parsed_dates(historical_data)
        fuel_prices = self._with_parsed_dates(fuel_prices)
        climate_data = self._with_parsed_dates(climate_data)
        holiday_data = self._with_parsed_dates(holiday_data)

        # Calendar features
        historical_data['DayOfWeek'] = historical_data['Date'].dt.dayofweek
        historical_data['Month'] = historical_data['Date'].dt.month
        historical_data['IsWeekend'] = (historical_data['DayOfWeek'] >= 5).astype(int)
        historical_data['Season'] = historical_data['Month'].map(self._get_season)

        # Missing-value handling
        historical_data = historical_data.fillna({
            'Demand': historical_data['Demand'].mean(),
            'Price': historical_data['Price'].mean(),
            'Capacity': historical_data['Capacity'].mode()[0]
        })
        fuel_prices['FuelPrice'] = fuel_prices['FuelPrice'].fillna(fuel_prices['FuelPrice'].mean())
        climate_data = climate_data.ffill().bfill()  # to fill the missing value.
        holiday_data = holiday_data.fillna('None')

        historical_data = self._remove_outliers(historical_data, ['Price', 'Demand'])
        historical_data = self._merge_data(historical_data, fuel_prices, climate_data, holiday_data)

        # Sort BEFORE computing the time-dependent features: both the rolling
        # mean and pct_change depend on row order within each group.
        historical_data = historical_data.sort_values(['Route', 'Airline', 'Date'])
        grouped = historical_data.groupby(['Route', 'Airline'])

        historical_data['Demand_MA7'] = grouped['Demand'].transform(
            lambda x: x.rolling(window=7, min_periods=1).mean()
        )

        # Elasticity = %change in demand / %change in price, per group.
        # Group-wise pct_change keeps the original row index, which avoids the
        # groupby.apply() result changing shape when there is only one group.
        price_pct = grouped['Price'].pct_change()
        demand_pct = grouped['Demand'].pct_change()
        elasticity = pd.Series(
            np.where(price_pct != 0, demand_pct / price_pct, 0),
            index=historical_data.index,
        )
        # The first row of each group has no previous value (NaN) and a zero
        # previous demand yields +/-inf; neither is a usable elasticity.
        historical_data['PriceElasticity'] = elasticity.replace([np.inf, -np.inf], np.nan).fillna(0)

        return historical_data

    def _get_season(self, month):
        """Return the season name for a calendar month number (1-12)."""
        if month in [12, 1, 2]: return 'Winter'
        if month in [3, 4, 5]: return 'Spring'
        if month in [6, 7, 8]: return 'Summer'
        return 'Fall'

    def _remove_outliers(self, df, columns, n_std=None):
        """Drop rows lying more than ``n_std`` standard deviations from the mean.

        Columns are processed in order, each on the rows that survived the
        previous column.

        Args:
            df: Frame to filter.
            columns: Names of the numeric columns to check.
            n_std: Threshold in standard deviations; defaults to the value
                read from the configuration when the processor was created.
        """
        if n_std is None:
            n_std = self.outlier_std_threshold
        for column in columns:
            mean = df[column].mean()
            std = df[column].std()
            # std is NaN for fewer than two rows; comparing against NaN would
            # silently discard every row, so leave such a column untouched.
            if not np.isfinite(std):
                continue
            df = df[(df[column] <= mean + (n_std * std)) & (df[column] >= mean - (n_std * std))]
        return df

    def _merge_data(self, historical_data, fuel_prices, climate_data, holiday_data):
        """Left-join fuel, climate and holiday data onto the booking history.

        Climate and holiday rows are matched on the flight date and the origin
        (the part of 'Route' before the first '-').  Adds 'Origin' and an
        'IsHoliday' flag; the raw 'Location' and 'HolidayName' columns are
        dropped.  Errors are logged and re-raised.
        """
        try:
            historical_data = pd.merge(historical_data, fuel_prices, on='Date', how='left')
            if 'FuelPrice' not in historical_data.columns:
                raise ValueError("Failed to merge fuel prices")

            historical_data['Origin'] = historical_data['Route'].apply(lambda x: x.split('-')[0])
            historical_data = pd.merge(historical_data, climate_data,
                                       left_on=['Date', 'Origin'],
                                       right_on=['Date', 'Location'],
                                       how='left')
            if 'Temperature' not in historical_data.columns or 'WeatherCondition' not in historical_data.columns:
                raise ValueError("Failed to merge climate data")
            historical_data = historical_data.drop(columns='Location')

            # Only a yes/no holiday flag is kept, so several holidays on the
            # same date and location must not multiply the booking rows.
            holiday_data = holiday_data.drop_duplicates(subset=['Date', 'Location'])
            historical_data = pd.merge(historical_data, holiday_data,
                                       left_on=['Date', 'Origin'],
                                       right_on=['Date', 'Location'],
                                       how='left')
            if 'HolidayName' not in historical_data.columns:
                raise ValueError("Failed to merge holiday data")

            historical_data['IsHoliday'] = (historical_data['HolidayName'].notna()).astype(int)
            historical_data = historical_data.drop(columns=['Location', 'HolidayName'])
            return historical_data

        except KeyError as e:
            logging.error(f"KeyError during merge: {e}")
            raise
        except ValueError as e:
            logging.error(f"ValueError during merge: {e}")
            raise
        except Exception as e:
            logging.error(f"Unexpected error during merge: {e}")
            raise

    def scale_prices(self, prices):
        """Clip prices to ``[0, historical_price_max]`` and scale them.

        Args:
            prices: Tensor or array-like of raw prices.

        Returns:
            A 1-D float32 tensor of scaled prices.
        """
        if isinstance(prices, torch.Tensor):
            prices = prices.detach().cpu().numpy()
        else:
            prices = np.array(prices)
        # An upper bound of None (not yet set) clips at the lower bound only.
        prices = np.clip(prices, 0, self.historical_price_max)
        scaled = self.price_scaler.transform(prices.reshape(-1, 1)).flatten()

        return torch.tensor(scaled, dtype=torch.float32)

    def inverse_scale_prices(self, scaled_prices):
        """Map scaled prices back to raw prices; returns a 1-D NumPy array."""
        if isinstance(scaled_prices, torch.Tensor):
            scaled_prices = scaled_prices.detach().cpu().numpy()
        else:
            scaled_prices = np.array(scaled_prices)
        return self.price_scaler.inverse_transform(scaled_prices.reshape(-1, 1)).flatten()

    def k_encode(self, value, encoder, reference_value):
        """One-hot encode ``value`` with the reference category dropped.

        Args:
            value: Category to encode.
            encoder: Mapping of category -> integer code.
            reference_value: Category represented by the all-zero vector.

        Returns:
            A list of ``len(encoder) - 1`` zeros/ones.  The reference value and
            values missing from ``encoder`` both encode to all zeros.

        Raises:
            ValueError: If ``value`` is a known non-reference category but
                ``reference_value`` itself is not in ``encoder``.
        """
        encoding = [0] * (len(encoder) - 1)
        if value == reference_value:
            return encoding
        index = encoder.get(value)
        if index is None:
            return encoding
        if reference_value not in encoder:
            raise ValueError(f"{reference_value!r} is not in list")
        # Codes above the reference shift down by one to close the gap left by
        # the dropped reference column.
        adjusted_index = index - 1 if index > encoder[reference_value] else index
        if adjusted_index < len(encoding):
            encoding[adjusted_index] = 1
        return encoding

In [None]:
class AirlinePricingEnv:
    """Day-by-day pricing simulation built on the processed booking history.

    Each call to ``step`` applies a price to every (route, airline) pair,
    estimates demand from the historical rolling demand and price elasticity,
    sells seats up to the remaining capacity and returns the revenue as the
    reward.
    """

    # Length of the state vector: 10 numeric features + 5 categorical codes.
    STATE_SIZE = 15
    # Integer code used for the season entry of the state vector.
    SEASON_CODES = {'Winter': 0, 'Spring': 1, 'Summer': 2, 'Fall': 3}

    def __init__(self, historical_data, fuel_prices, climate_data, holiday_data, data_processor):
        """Prepare the environment and reset it to the first historical date.

        Args:
            historical_data, fuel_prices, climate_data, holiday_data: Raw input
                frames; they are passed through ``data_processor.transform``.
            data_processor: A ``DataProcessor``; its price scaler is fitted
                here on the processed prices.

        Raises:
            ValueError: If any frame is empty or ``data_processor`` has the
                wrong type.
        """
        if historical_data.empty or fuel_prices.empty or climate_data.empty or holiday_data.empty:
            raise ValueError("Input data cannot be empty")

        if not isinstance(data_processor, DataProcessor):
            raise ValueError("data_processor must be an instance of DataProcessor")

        self.data_processor = data_processor
        self.historical_data = self.data_processor.transform(historical_data, fuel_prices, climate_data, holiday_data)
        self.routes = self.historical_data['Route'].unique()
        self.airlines = self.historical_data['Airline'].unique()
        self.aircraft_types = self.historical_data['AircraftType'].unique()
        self.current_date = self.historical_data['Date'].min()
        # Config fields are lower-case attributes of the dataclass.
        self.action_days_ahead = config.action_days_ahead  # Define how many days ahead the agent sets prices
        self.simulation_length_days = config.simulation_length_days
        self.seats_capacity = config.seats_capacity
        self.prices = {}  # (route, airline): price. Store ONLY current prices.
        self.seats_sold = {}
        self.data_processor.historical_price_max = self.historical_data['Price'].max()
        self.data_processor.price_scaler.fit(self.historical_data['Price'].values.reshape(-1, 1))
        self.seasonal_indices = self._calculate_seasonal_indices()
        self.historical_price_mean = self.historical_data['Price'].mean()
        self.preprocessed_data = self._preprocess_data()  # Preprocess for faster lookups
        # Most frequent aircraft per (route, airline), computed once instead of
        # re-filtering the whole frame on every state request.
        self._aircraft_type_by_pair = self._most_common_aircraft()
        self.reset()

    def _preprocess_data(self):
        """Index the history as ``{(route, airline): {date: row_dict}}``."""
        data_dict = {}
        for (route, airline), group in self.historical_data.groupby(['Route', 'Airline']):
            # to_dict('index') requires unique keys; if a date occurs more than
            # once for a pair, keep its last row.
            unique_days = group.drop_duplicates(subset='Date', keep='last')
            data_dict[(route, airline)] = unique_days.set_index('Date').to_dict('index')
        return data_dict

    def _most_common_aircraft(self):
        """Return ``{(route, airline): most frequent AircraftType}``."""
        aircraft = {}
        for pair, group in self.historical_data.groupby(['Route', 'Airline']):
            modes = group['AircraftType'].mode()
            if not modes.empty:
                aircraft[pair] = modes[0]
        return aircraft

    def reset(self, historical_data=None):
        """Restart the simulation and return the initial state.

        Args:
            historical_data: Optional frame whose earliest 'Date' becomes the
                start date; defaults to the environment's own history.
        """
        if historical_data is None:
            historical_data = self.historical_data
        self.current_date = historical_data['Date'].min()
        self.current_step = 0
        self.prices = {}  # Reset prices
        self.seats_sold = {}
        return self._get_state(historical_data)

    def _get_state(self, historical_data):
        """Build the scaled state vector for the current date.

        The state always describes the first route and first airline.  On any
        error the problem is logged and an all-zero vector is returned so the
        simulation can continue.

        Args:
            historical_data: Unused; kept for interface compatibility.

        Returns:
            A float32 tensor of length ``STATE_SIZE``.
        """
        try:
            flight_date = self.current_date
            route = self.routes[0]
            airline = self.airlines[0]

            # Use preprocessed data for faster lookup
            rows_by_date = self.preprocessed_data.get((route, airline))
            row = rows_by_date.get(flight_date) if rows_by_date else None

            if row is None:
                # No history for this date: derive calendar features from the
                # date itself and fall back to neutral defaults for the rest.
                day_of_week = flight_date.weekday()
                month = flight_date.month
                is_weekend = 1 if day_of_week >= 5 else 0
                season = self.data_processor._get_season(month)
                fuel_price = self.historical_data['FuelPrice'].mean()
                temperature, weather_condition, is_holiday = 25, 'Sunny', 0
                demand_ma7 = self.historical_data['Demand_MA7'].mean()
                price_elasticity, seats_sold, current_price = 0, 0, 0.0
            else:
                day_of_week, month = int(row['DayOfWeek']), int(row['Month'])
                is_weekend, season = int(row['IsWeekend']), row['Season']
                fuel_price, temperature = float(row['FuelPrice']), float(row['Temperature'])
                weather_condition = row['WeatherCondition']
                is_holiday = int(row['IsHoliday'])
                demand_ma7 = float(row['Demand_MA7'])
                price_elasticity = float(row['PriceElasticity'])
                seats_sold = self.seats_sold.get((route, airline, flight_date), 0)
                current_price = self.prices.get((route, airline), 0.0)  # Get current price for the (route, airline)

            aircraft_type = self._aircraft_type_by_pair.get((route, airline))
            remaining_capacity = self.seats_capacity - seats_sold
            encoders = self.data_processor

            state = [float(day_of_week), float(month), float(is_weekend), float(fuel_price),
                     float(temperature), float(is_holiday), float(remaining_capacity), float(current_price),
                     float(demand_ma7), float(price_elasticity)]
            # Categorical features enter the state as their integer codes
            # (0.0 for a category the encoder has never seen).
            state.extend([float(encoders.route_encoder.get(route, 0.0)),
                          float(encoders.airline_encoder.get(airline, 0.0)),
                          float(encoders.aircraft_encoder.get(aircraft_type, 0.0)),
                          float(encoders.weather_encoder.get(weather_condition, 0.0)),
                          float(self.SEASON_CODES.get(season, 0.0))])

            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            scaled_state = torch.from_numpy(encoders.state_scaler.transform(state_tensor.numpy())).float()

            if torch.isnan(scaled_state).any() or torch.isinf(scaled_state).any():
                logging.warning("Scaled state contains NaN/inf values; using fallback.")
                scaled_state = torch.zeros_like(scaled_state)
            return scaled_state.squeeze(0)

        except Exception as e:
            logging.error(f"Error generating state: {str(e)}")
            return torch.zeros(self.STATE_SIZE, dtype=torch.float32)

    def _calculate_seasonal_indices(self):
        """Return per-pair Series of (season mean demand / overall mean demand)."""
        seasonal_indices = {}
        for (route, airline), group in self.historical_data.groupby(['Route', 'Airline']):
            seasonal_avg = group.groupby('Season')['Demand'].mean()
            overall_avg = group['Demand'].mean()
            seasonal_indices[(route, airline)] = seasonal_avg / overall_avg
        return seasonal_indices

    def _get_demand(self, route, airline, flight_date, price):
        """Estimate the number of seats demanded at ``price``.

        Base demand is the 7-row rolling demand of the matching (or nearest)
        historical date, scaled by the seasonal index and by 1.5 on holidays,
        then adjusted linearly by the price elasticity times the relative
        deviation of ``price`` from the historical mean price.

        Returns:
            A non-negative int.  0 is returned when the estimate is not a
            finite number.
        """
        # Stays None when the (route, airline) pair has no history at all.
        row = None
        try:
            rows_by_date = self.preprocessed_data[(route, airline)]
            row = rows_by_date.get(flight_date)
            if row is None:
                # No exact match: use the historical date closest to the flight.
                closest_date = min(rows_by_date, key=lambda day: abs(day - flight_date))
                row = rows_by_date[closest_date]
            base_demand, price_elasticity = row['Demand_MA7'], row['PriceElasticity']
        except KeyError:
            # Handle cases where the (route, airline) combination doesn't exist.
            base_demand, price_elasticity = self.historical_data['Demand'].mean(), -0.5

        season = self.data_processor._get_season(flight_date.month)
        seasonal_index = self.seasonal_indices.get((route, airline), {}).get(season, 1.0)
        # A zero overall mean demand yields NaN/inf indices; treat those as neutral.
        if not np.isfinite(seasonal_index):
            seasonal_index = 1.0
        base_demand *= seasonal_index

        is_holiday = int(row['IsHoliday']) if row and 'IsHoliday' in row else 0
        if is_holiday == 1:
            base_demand *= 1.5

        relative_price = (price - self.historical_price_mean) / self.historical_price_mean
        estimate = base_demand * (1 + price_elasticity * relative_price)
        if not np.isfinite(estimate):
            # int() cannot convert NaN/inf; report no demand instead of crashing.
            logging.debug("Non-finite demand estimate for %s/%s on %s", route, airline, flight_date)
            return 0
        return max(0, int(estimate))

    def step(self, action):
        """Executes one step (one day) in the environment.

        Args:
            action (torch.Tensor): Scaled price; its first element is passed
                through the inverse price scaler and applied to every
                route/airline/day combination of this step.

        Returns:
            tuple: Contains (next_state, reward, done, info) where:
                next_state (torch.Tensor): Next state representation
                reward (torch.Tensor): Total revenue from this step
                done (torch.Tensor): Whether simulation is complete
                info (dict): Additional information (currently empty)
        """
        # Convert action to a NumPy array if it's a tensor
        if isinstance(action, torch.Tensor):
            action = action.detach().cpu().numpy()

        # NOTE(review): the value is fed to the inverse price scaler as-is.  If
        # the agent emits a discrete action *index*, it needs to be mapped to a
        # scaled price first -- confirm against the agent's action space.
        # The price does not depend on the loop variables, so compute it once.
        price = self.data_processor.inverse_scale_prices(np.array([action]))[0]

        total_revenue = 0
        for route in self.routes:
            for airline in self.airlines:
                # Agent sets price for the current day (or a small number of future days)
                for days_ahead in range(self.action_days_ahead):
                    flight_date = self.current_date + pd.Timedelta(days=days_ahead)
                    self.prices[(route, airline)] = price

                    seats_sold_key = (route, airline, flight_date)
                    seats_already_sold = self.seats_sold.get(seats_sold_key, 0)
                    demand = self._get_demand(route, airline, flight_date, price)
                    # Sales are capped by the seats still available on that flight.
                    remaining_capacity = self.seats_capacity - seats_already_sold
                    actual_demand = min(demand, remaining_capacity)
                    self.seats_sold[seats_sold_key] = seats_already_sold + actual_demand
                    total_revenue += actual_demand * price

        # Advance the environment
        self.current_date += pd.Timedelta(days=1)
        self.current_step += 1
        done = torch.tensor(self.current_step >= self.simulation_length_days, dtype=torch.bool)
        next_state = self._get_state(self.historical_data)
        return next_state, torch.tensor(total_revenue, dtype=torch.float32), done, {}  # Return a float reward

In [None]:
class DQNAgent(nn.Module):  # Inherit from nn.Module
    """Deep Q-learning agent with a replay buffer and a target network.

    ``model`` is the online Q-network that is trained; ``target_model`` is a
    copy used to compute bootstrap targets.  Actions are discrete indices in
    ``[0, action_size)``.
    """

    def __init__(self, state_size, action_size, device):
        """Create the networks, optimizer and replay buffer.

        Args:
            state_size: Length of the state vector (positive).
            action_size: Number of discrete actions (positive).
            device: Torch device the networks live on.

        Raises:
            ValueError: If ``state_size`` or ``action_size`` is not positive.
        """
        super(DQNAgent, self).__init__()  # Call superclass constructor
        if state_size <= 0 or action_size <= 0:
            raise ValueError("state_size and action_size must be positive integers")

        self.state_size = state_size
        self.action_size = action_size
        # Config fields are lower-case attributes of the dataclass.
        self.memory = deque(maxlen=config.memory_size)
        self.gamma = config.gamma
        self.epsilon = config.epsilon
        self.epsilon_min = config.epsilon_min
        self.epsilon_decay = config.epsilon_decay
        self.learning_rate = config.learning_rate
        self.batch_size = config.batch_size
        self.device = device
        self.model = self._build_model().to(self.device)  # Move model to device
        self.target_model = self._build_model().to(self.device)  # Move target model to device
        # Start the target network as an exact copy of the online network;
        # two independently initialised networks would give unrelated targets.
        self.update_target_model()
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)  # Define optimizer
        self.loss_fn = nn.HuberLoss()  # Define loss function
        # NOTE(review): nothing in this class advances the counter below;
        # presumably the training loop does and calls update_target_model()
        # every ``target_update_frequency`` ticks -- confirm.
        self.update_target_counter = 0
        self.target_update_frequency = config.target_update_freq

    def _build_model(self):
        """Build the neural network model for Q-value approximation."""
        model = nn.Sequential(
            nn.Linear(self.state_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, self.action_size)
        )
        return model

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())
        # The target network is only ever used for inference.
        self.target_model.eval()

    def remember(self, state, action, reward, next_state, done):
        """Store one transition in the replay buffer (action is an index tensor)."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state, training=True):
        """Choose an action for a single state (epsilon-greedy when training).

        Args:
            state: 1-D state tensor.
            training: When True, a random action is taken with probability
                ``epsilon``.

        Returns:
            An int64 tensor of shape ``[1]`` holding the action index.
        """
        if training and np.random.rand() <= self.epsilon:
            return torch.randint(0, self.action_size, (1,), device=self.device, dtype=torch.int64)  # Return action index

        state = state.unsqueeze(0).to(self.device)
        # BatchNorm cannot normalise a batch of one in training mode and
        # Dropout would randomise the Q-values, so evaluate in eval mode and
        # restore the previous mode afterwards.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                q_values = self.model(state)
        finally:
            self.model.train(was_training)
        return torch.argmax(q_values, dim=1)  # Return a LongTensor, shape [1]

    def replay(self, batch_size: int) -> float:
        """Run one optimisation step on a random minibatch and decay epsilon.

        Args:
            batch_size: Number of transitions to sample.

        Returns:
            The loss of this step as a Python float.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions.
        """
        if batch_size > len(self.memory):
            raise ValueError(f"Batch size {batch_size} exceeds memory size {len(self.memory)}")

        try:
            minibatch = random.sample(self.memory, batch_size)
            states, actions, rewards, next_states, dones = map(torch.stack, zip(*minibatch))
            states = states.to(self.device)
            actions = actions.to(self.device, dtype=torch.int64)  # [batch_size, 1]
            rewards = rewards.to(self.device)
            next_states = next_states.to(self.device)
            dones = dones.to(self.device)

            # Online network learns (train mode); target network only predicts.
            self.model.train()
            self.target_model.eval()

            current_q_values = self.model(states)  # [batch_size, action_size]

            with torch.no_grad():
                next_q_values = self.target_model(next_states)  # [batch_size, action_size]
            max_next_q_values = next_q_values.max(1)[0]  # [batch_size]

            # Bellman target; terminal transitions contribute the reward only.
            target_values = rewards + self.gamma * max_next_q_values * (~dones)  # [batch_size]

            # Targets equal the current predictions everywhere except at the
            # action actually taken, so only that entry produces a loss.  The
            # copy is detached: targets are constants, not part of the graph.
            targets = current_q_values.detach().clone()  # [batch_size, action_size]
            targets.scatter_(1, actions, target_values.unsqueeze(1))

            loss = self.loss_fn(current_q_values, targets)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
            logging.debug(f"Training loss: {loss.item():.4f}, Epsilon: {self.epsilon:.4f}")
            return loss.item()

        except Exception as e:
            logging.error(f"Error during replay: {e}")
            raise

    def load(self, name):
        """Load saved weights into both the online and the target network.

        Args:
            name: Path of a checkpoint written by ``save``.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        # map_location lets a checkpoint saved on another device be restored;
        # the file is read once and shared by both networks.
        state_dict = torch.load(name, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.target_model.load_state_dict(state_dict)
        self.model.to(self.device)  # Make sure its on the correct device
        self.target_model.to(self.device)  # Make sure its on the correct device

    def save(self, name):
        """Write the online network's weights to ``name``."""
        torch.save(self.model.state_dict(), name)  # Save state_dict

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error

class PerformanceMetrics:
    """Collects per-episode and per-step figures and summarises or plots them."""

    def __init__(self):
        # Per-episode series
        self.episode_rewards = []
        self.loss_history = []
        # Per-step series: agent-side values and the values they are compared to
        self.demand_predictions = []
        self.actual_demands = []
        self.price_predictions = []
        self.actual_prices = []

    def log_episode(self, total_reward, avg_loss):
        """Record the total reward and the average loss of one episode."""
        self.episode_rewards.append(total_reward)
        self.loss_history.append(avg_loss)

    def log_predictions(self, predicted_demand, actual_demand, predicted_price, actual_price):
        """Record one step's demand and price, each with its reference value.

        Args:
            predicted_demand: The demand predicted by the agent.
            actual_demand: The actual demand observed in the environment.
            predicted_price: The price set by the agent.
            actual_price: The actual/historical price.
        """
        recorded = (
            (self.demand_predictions, predicted_demand),
            (self.actual_demands, actual_demand),
            (self.price_predictions, predicted_price),
            (self.actual_prices, actual_price),
        )
        for series, value in recorded:
            series.append(value)

    def calculate_metrics(self):
        """Return a dict of summary metrics.

        Reward and loss statistics are always present.  MAPE and RMSE for
        demand and price are ``None`` when no predictions have been logged.
        """
        summary = {
            'Mean Episode Reward': np.mean(self.episode_rewards),
            'Reward Std Deviation': np.std(self.episode_rewards),
            'Mean Training Loss': np.mean(self.loss_history),
        }

        comparisons = (
            ('Demand MAPE', 'Demand RMSE', self.actual_demands, self.demand_predictions),
            ('Price MAPE', 'Price RMSE', self.actual_prices, self.price_predictions),
        )
        for mape_key, rmse_key, actual, predicted in comparisons:
            if predicted and actual:
                summary[mape_key] = mean_absolute_percentage_error(actual, predicted)
                summary[rmse_key] = np.sqrt(mean_squared_error(actual, predicted))
            else:
                # Nothing logged yet for this quantity.
                summary[mape_key] = None
                summary[rmse_key] = None

        return summary

    @staticmethod
    def _draw_error_panel(ax, actual, predicted, title, missing_title):
        """Plot the absolute percentage error per step, or label the panel as empty."""
        if not (predicted and actual):
            ax.set_title(missing_title)
            return
        actual_arr = np.array(actual)
        # The small constant keeps the division defined when an actual value is zero.
        errors = np.abs(actual_arr - np.array(predicted)) / (actual_arr + 1e-9)
        ax.plot(errors)
        ax.set_title(title)
        ax.set_xlabel('Step')
        ax.set_ylabel('Error')

    def plot_training_progress(self):
        """Show a 2x2 figure: rewards, losses, demand error and price error."""
        fig, axs = plt.subplots(2, 2, figsize=(15, 10))

        # Top row: one point per episode
        episode_panels = (
            (axs[0, 0], self.episode_rewards, 'Episode Rewards', 'Reward'),
            (axs[0, 1], self.loss_history, 'Training Loss', 'Loss'),
        )
        for ax, series, title, y_label in episode_panels:
            ax.plot(series)
            ax.set_title(title)
            ax.set_xlabel('Episode')
            ax.set_ylabel(y_label)

        # Bottom row: one point per logged step
        self._draw_error_panel(
            axs[1, 0], self.actual_demands, self.demand_predictions,
            'Demand Prediction Error (Absolute Percentage Error)',
            'Demand Prediction Error (Not Available)',
        )
        self._draw_error_panel(
            axs[1, 1], self.actual_prices, self.price_predictions,
            'Price Prediction Error (Absolute Percentage Error)',
            'Price Prediction Error (Not Available)',
        )

        plt.tight_layout()
        plt.show()

In [None]:
def train_model(env, agent, n_episodes, validation_data=None, writer=None):
    """
    Train the DQN agent on the environment.

    Each episode resets the environment, then steps it for up to
    ``env.simulation_length_days`` steps (or until the environment reports
    ``done``), storing transitions in the agent's memory, replaying a batch
    once the memory holds more than ``agent.batch_size`` transitions, and
    syncing the target network every ``agent.target_update_frequency`` steps.

    Args:
        env: The AirlinePricingEnv environment.
        agent: The DQNAgent.
        n_episodes: The maximum number of training episodes.
        validation_data: Optional validation data. When given, the agent is
            evaluated on it after every episode; the best-scoring model is
            saved to 'best_model.pth' and training stops early once the
            validation reward has not improved for ``config.PATIENCE``
            consecutive episodes.
        writer: Optional SummaryWriter for TensorBoard logging.

    Returns:
        The trained agent.
    """
    best_reward = float('-inf')
    patience_counter = 0
    metrics_tracker = PerformanceMetrics()

    for episode in range(n_episodes):
        state = env.reset()
        state = state.to(agent.device)
        total_reward = 0
        losses = []

        with tqdm(total=env.simulation_length_days, desc=f"Episode {episode + 1}/{n_episodes}", unit="step") as pbar:
            # The loop index must NOT be named `time`: that would shadow the
            # `time` module and make the `time.time()` calls below fail.
            for step in range(env.simulation_length_days):
                action = agent.act(state)  # Get the action INDEX
                start_time = time.time()
                next_state, reward, done, _ = env.step(action)
                step_time = time.time() - start_time

                next_state = next_state.to(agent.device)

                # Collect the simulated and historical demand/price for logging.
                route = env.routes[0]  # Assuming single route for now
                airline = env.airlines[0]  # Assuming single airline for now
                # env.step() advances the date, so look at the *previous* day.
                flight_date = env.current_date - pd.Timedelta(days=1)
                demand = env._get_demand(route, airline, flight_date, env.prices[(route, airline)])
                price = env.prices[(route, airline)]

                # Filter the historical rows once and reuse the result for
                # both the demand and the price lookup.
                history = env.historical_data
                matching_rows = history[
                    (history['Date'] == flight_date)
                    & (history['Route'] == route)
                    & (history['Airline'] == airline)
                ]
                metrics_tracker.log_predictions(
                    demand,
                    matching_rows['Demand'].iloc[0],
                    price,
                    matching_rows['Price'].iloc[0],
                )

                agent.remember(state, action, reward, next_state, done)

                if len(agent.memory) > agent.batch_size:
                    loss = agent.replay(agent.batch_size)
                    losses.append(loss)
                    if writer:
                        # Global step index across all episodes.
                        writer.add_scalar("Loss/train", loss, episode * env.simulation_length_days + step)

                state = next_state
                total_reward += reward  # Accumulate the reward
                agent.update_target_counter += 1

                # Periodically copy the online network into the target network.
                if agent.update_target_counter >= agent.target_update_frequency:
                    agent.target_model.load_state_dict(agent.model.state_dict())
                    agent.update_target_counter = 0

                pbar.update(1)
                pbar.set_postfix({
                    "Reward": f"{total_reward:.2f}",
                    "Avg Loss": f"{np.mean(losses) if losses else 0:.4f}",
                    "Epsilon": f"{agent.epsilon:.4f}",
                    "Step Time": f"{step_time:.2f}s"
                })

                if done:
                    break

        # Record the episode before any early-stopping decision so the final
        # episode is never missing from the metrics or from TensorBoard.
        avg_loss = np.mean(losses) if losses else 0.0
        metrics_tracker.log_episode(total_reward, avg_loss, total_reward)
        logging.info(f"Episode: {episode + 1}/{n_episodes}, Total Reward: {total_reward}, Avg Loss: {avg_loss}, Epsilon: {agent.epsilon:.3f}")
        if writer:
            writer.add_scalar("Reward/train", total_reward, episode)
            writer.add_scalar("Epsilon", agent.epsilon, episode)

        if validation_data is not None:
            val_reward = evaluate_model(env, agent, validation_data)
            # Log before the early-stop check so the last value is kept too.
            if writer:
                writer.add_scalar("Reward/validation", val_reward, episode)
            if val_reward > best_reward:
                best_reward = val_reward
                patience_counter = 0
                agent.save('best_model.pth')
            else:
                patience_counter += 1
            # NOTE(review): `config` is resolved at module scope, outside this
            # block; confirm it exposes PATIENCE (the Config dataclass at the
            # top of the file declares the field as lowercase `patience`).
            if patience_counter >= config.PATIENCE:
                logging.info("Early stopping triggered!")
                break

    # Calculate and plot metrics after training is complete
    performance_metrics = metrics_tracker.calculate_metrics()
    print(performance_metrics)
    metrics_tracker.plot_training_progress()

    return agent

def evaluate_model(env, agent, validation_data):
    """Run one exploration-free episode and return the summed reward.

    The environment is reset with ``validation_data`` and then stepped for at
    most ``env.simulation_length_days`` steps, stopping sooner if the
    environment reports that the episode is done.
    """
    observation = env.reset(validation_data).to(agent.device)
    cumulative_reward = 0

    for _ in range(env.simulation_length_days):
        # training=False asks the agent for its action without exploration.
        chosen_action = agent.act(observation, training=False)
        successor, step_reward, finished, _info = env.step(chosen_action)
        observation = successor.to(agent.device)
        cumulative_reward += step_reward
        if finished:
            break

    return cumulative_reward

In [None]:
def create_holiday_data(historical_data):
    """Create a randomly generated sample holiday dataset.

    For every origin found in the 'Route' column (the part before the first
    '-'), between 2 and 4 holidays are placed on dates drawn at random from
    the 'Date' column, using the global ``random`` module state.

    Args:
        historical_data: DataFrame with a 'Date' column and a 'Route' column
            whose values look like 'Origin-Destination'.

    Returns:
        DataFrame with columns 'Date', 'Location' and 'HolidayName'. The
        columns are present even when no holidays are generated (for example
        when ``historical_data`` has no rows).
    """
    columns = ['Date', 'Location', 'HolidayName']

    # Get unique dates and locations
    dates = historical_data['Date'].unique()
    # Expecting Route to be in format 'Origin-Destination'
    locations = np.unique([route.split('-')[0] for route in historical_data['Route'].unique()])

    holiday_data = []
    for location in locations:
        # Create 2-4 random holidays per location
        num_holidays = random.randint(2, 4)
        for holiday_number in range(1, num_holidays + 1):
            random_date = pd.to_datetime(random.choice(dates))
            holiday_data.append({
                'Date': random_date,
                'Location': location,
                'HolidayName': f"{location} Holiday {holiday_number}",
            })

    # Passing the column list keeps the schema stable for an empty result.
    return pd.DataFrame(holiday_data, columns=columns)

def setup_gpu(disable_gpu=False, memory_limit=None):
    """Pick the PyTorch device to run on and log the choice.

    Args:
        disable_gpu (bool): When true, CPU is selected without checking for CUDA.
        memory_limit (int): Accepted but not used by this function.

    Returns:
        torch.device: "cuda" when the GPU is allowed and available, else "cpu".
    """
    if disable_gpu:
        device_name = "cpu"
        message = "GPU disabled by user request. Running on CPU."
    elif not torch.cuda.is_available():
        device_name = "cpu"
        message = "No GPU found, using CPU."
    else:
        device_name = "cuda"
        message = f"Using GPU: {torch.cuda.get_device_name(0)}"

    logging.info(message)
    return torch.device(device_name)


def add_gpu_arguments(parser):
    """Register the 'GPU Settings' option group on ``parser`` and return it."""
    gpu_group = parser.add_argument_group('GPU Settings')

    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        ('--disable_gpu', {
            'action': 'store_true',
            'help': 'Disable GPU usage, even if available.',
        }),
        # Kept for compatibility; per its help text it is not used in PyTorch.
        ('--memory_limit', {
            'type': int,
            'default': None,
            'help': '(Not directly used in PyTorch) Limit GPU memory usage (in MB).',
        }),
    )
    for flag, options in option_specs:
        gpu_group.add_argument(flag, **options)

    return parser


def main(n_episodes, batch_size, learning_rate, gamma, epsilon_decay, target_update_freq, device):
    """Run the airline pricing simulation end to end.

    Loads the flight and location CSV files from the working directory,
    builds the data processor, environment and DQN agent, trains the agent,
    evaluates it on the historical data and saves it to 'final_model.pth'.

    Args:
        n_episodes: Number of training episodes.
        batch_size: Replay batch size assigned to the agent.
        learning_rate: Learning rate assigned to the agent.
        gamma: Discount factor assigned to the agent.
        epsilon_decay: Exploration decay factor assigned to the agent.
        target_update_freq: Steps between target-network updates.
        device: torch.device handed to the agent.
    """

    logging.info("Loading data from CSV...")
    # One CSV per route; they are stacked into a single frame.
    route_files = (
        'flight_data_BOM_BLR.csv',
        'flight_data_DEL_BLR.csv',
        'flight_data_DEL_CCU.csv',
    )
    historical_data = pd.concat([pd.read_csv(path) for path in route_files], ignore_index=True)

    # Load other necessary data
    climate_data = pd.read_csv('location_data.csv')

    # No holiday file is available, so generate sample holidays with create_holiday_data.
    holiday_data = create_holiday_data(historical_data)
    # No fuel price file either: use a constant example fuel price for every date.
    fuel_prices = pd.DataFrame({'Date': historical_data['Date'].unique(), 'FuelPrice': 2.5})
    fuel_prices['Date'] = pd.to_datetime(fuel_prices['Date'])

    # Initialize data processor and fit it with the data
    logging.info("Initializing data processor...")
    data_processor = DataProcessor()
    data_processor.fit(historical_data, fuel_prices, climate_data, holiday_data)

    logging.info("Creating environment...")
    env = AirlinePricingEnv(historical_data, fuel_prices, climate_data, holiday_data, data_processor)

    state_size = env.reset().shape[0]  # Get state size directly from tensor
    # As the action is one value now.
    action_size = len(env.routes) * len(env.airlines) * config.ACTION_DAYS_AHEAD

    logging.info("Creating DQN agent...")
    agent = DQNAgent(state_size, action_size, device)  # Pass device to agent
    # NOTE(review): these are plain attribute assignments made after
    # construction; whether e.g. learning_rate reaches an already-built
    # optimizer depends on DQNAgent, which is outside this block - confirm.
    agent.gamma = gamma
    agent.learning_rate = learning_rate
    agent.epsilon_decay = epsilon_decay
    agent.target_update_frequency = target_update_freq
    agent.batch_size = batch_size

    # Initialize TensorBoard writer
    writer = SummaryWriter()
    try:
        logging.info("Starting training...")
        # The writer must go in by keyword: the fourth positional parameter of
        # train_model is validation_data. train_model returns only the agent.
        trained_agent = train_model(env, agent, n_episodes, writer=writer)

        logging.info("Evaluating final model...")
        final_reward = evaluate_model(env, trained_agent, historical_data)
        logging.info(f"Final reward: {final_reward}")

        logging.info("Saving model...")
        trained_agent.save('final_model.pth')  # Save as .pth
    finally:
        # Close the writer even when training or evaluation fails.
        writer.close()
    logging.info("Training complete!")


if __name__ == "__main__":
    # Script entry point: build the CLI, choose the device, run the pipeline.
    parser = add_gpu_arguments(
        argparse.ArgumentParser(description="Train DQN agent for airline pricing.")
    )

    # (flag, type, default, help) for each hyperparameter. There are no file
    # path arguments; main() reads fixed CSV file names.
    hyperparameter_flags = (
        ('--n_episodes', int, config.N_EPISODES, 'Training episodes.'),
        ('--batch_size', int, config.BATCH_SIZE, 'Training batch size.'),
        ('--learning_rate', float, config.LEARNING_RATE, 'Optimizer learning rate.'),
        ('--gamma', float, config.GAMMA, 'Discount factor.'),
        ('--epsilon_decay', float, config.EPSILON_DECAY, 'Epsilon decay rate.'),
        ('--target_update_freq', int, config.TARGET_UPDATE_FREQ, 'Target network update frequency.'),
    )
    for flag, value_type, default_value, help_text in hyperparameter_flags:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    args = parser.parse_args()

    # --memory_limit is parsed but not forwarded to setup_gpu.
    device = setup_gpu(args.disable_gpu)
    main(
        n_episodes=args.n_episodes,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        gamma=args.gamma,
        epsilon_decay=args.epsilon_decay,
        target_update_freq=args.target_update_freq,
        device=device,
    )