In [4]:
import sys
sys.path.append("..")
import os
import numpy as np
from stable_baselines3 import DQN
from src.environment import AircraftDisruptionEnv
from scripts.visualizations import StatePlotter
from scripts.utils import load_scenario_data
from src.config import *
import re
import torch
import time
import ipywidgets as widgets
from IPython.display import display, Image as IPImage
from io import BytesIO
import matplotlib.pyplot as plt
from datetime import timedelta

# Load the model and run inference
def run_inference_dqn(model_path, scenario_folder, env_type, seed, plot_title):
    # Load the scenario data
    data_dict = load_scenario_data(scenario_folder)

    # Extract necessary data for the environment
    aircraft_dict = data_dict['aircraft']
    flights_dict = data_dict['flights']
    rotations_dict = data_dict['rotations']
    alt_aircraft_dict = data_dict['alt_aircraft']
    config_dict = data_dict['config']

    # Initialize the environment
    env = AircraftDisruptionEnv(
        aircraft_dict, 
        flights_dict, 
        rotations_dict, 
        alt_aircraft_dict, 
        config_dict,
        env_type=env_type
    )

    # Load the trained model and set the environment
    model = DQN.load(model_path)
    model.set_env(env)

    # Set model to evaluation mode
    model.policy.set_training_mode(False)
    model.exploration_rate = 0.0

    # Set random seed for reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)

    print(f"seed: {seed}")

    # Create StatePlotter object for visualizing the environment state
    state_plotter = StatePlotter(
        aircraft_dict=env.aircraft_dict,
        flights_dict=env.flights_dict,
        rotations_dict=env.rotations_dict,
        alt_aircraft_dict=env.alt_aircraft_dict,
        start_datetime=env.start_datetime,
        end_datetime=env.end_datetime,
        uncertain_breakdowns=env.uncertain_breakdowns,
        plot_title=plot_title
    )

    # Reset the environment for inference
    obs, _ = env.reset()
    done_flag = False
    total_reward = 0
    step_num = 0
    max_steps = 1000  # Set a maximum number of steps to prevent infinite loops
    
    # Keep track of all flights that were chosen as actions
    chosen_flight_actions = []

    while not done_flag and step_num < max_steps:
        # Visualize the current state
        print(f"Step {step_num}:")

        # Extract necessary information from the environment for plotting
        swapped_flights = env.swapped_flights
        environment_delayed_flights = env.environment_delayed_flights
        current_datetime = env.current_datetime

        # Retrieve the updated dictionaries from the environment
        updated_flights_dict = env.flights_dict
        updated_rotations_dict = env.rotations_dict
        updated_alt_aircraft_dict = env.alt_aircraft_dict
        print("**updated_alt_aircraft_dict:")
        print(updated_alt_aircraft_dict)
        cancelled_flights = env.penalized_cancelled_flights

        if DEBUG_MODE_VISUALIZATION:
            print("Flights Dict:")
            print(updated_flights_dict)
            print("Alt Aircraft Dict:")
            print(updated_alt_aircraft_dict)
            print("Swapped Flights:")
            print(swapped_flights)
            print("Environment Delayed Flights:")
            print(environment_delayed_flights)
            print("Cancelled Flights:")
            print(cancelled_flights)
            print("Unavailabilities:")
            print(env.alt_aircraft_dict)
            print("Uncertain Breakdowns:")
            for key, value in env.uncertain_breakdowns.items():
                print(f"{key}: {value}")
            print("Current Breakdowns:")
            print(env.current_breakdowns)
            print("")

        # Update the StatePlotter's dictionaries with the updated ones
        state_plotter.alt_aircraft_dict = updated_alt_aircraft_dict
        state_plotter.flights_dict = updated_flights_dict
        state_plotter.rotations_dict = updated_rotations_dict

        if 'reward' not in locals():
            reward = 0
            action = 0
        
        # Plot the current state
        fig = state_plotter.plot_state(
            updated_flights_dict, 
            swapped_flights, 
            environment_delayed_flights, 
            cancelled_flights, 
            current_datetime, 
            title_appendix=env_type,
            show_plot=True,
            reward_and_action=(reward, env.map_index_to_action(action), total_reward)
        )
        plt.show()

        # Get the action mask from the environment
        action_mask = obs['action_mask']

        # Convert observation to float32
        obs = {key: np.array(value, dtype=np.float32) for key, value in obs.items()}

        # Get the action mask from the observation
        action_mask = obs.get('action_mask', None)
        if action_mask is None:
            raise ValueError("Action mask is missing in the observation!")

        # Get the Q-values and apply the action mask
        obs_tensor = model.policy.obs_to_tensor(obs)[0]
        q_values = model.policy.q_net(obs_tensor).detach().cpu().numpy().squeeze()

        # Mask invalid actions by setting their Q-values to -inf
        masked_q_values = q_values.copy()
        masked_q_values[action_mask == 0] = -np.inf

        if env_type == 'reactive':
            print(f"*** amount of allowed actions: {np.sum(action_mask)}")
        # Predict the action using the masked Q-values
        action = np.argmax(masked_q_values)

        # Verify if the action is valid
        if action_mask[action] == 0:
            raise ValueError(f"Invalid action selected by the model: {action}")

        # Take action in the environment
        obs, reward, terminated, truncated, info = env.step(action)

        # Store the chosen flight action
        action_mapped = env.map_index_to_action(action)
        if action_mapped[0] != 'no_op':  # Only store if it's not a no_op action
            chosen_flight_actions.append(action_mapped[0])

        # Accumulate the reward
        total_reward += reward

        print("action index:")
        print(action)
        print("action mapped:")
        print(action_mapped)
        print(f"Action taken: {action_mapped}, Reward: {reward}")

        # Combine terminated and truncated flags
        done_flag = terminated or truncated

        step_num += 1

    print("================================================")
    print("Final state:")

    current_datetime += timedelta(hours=TIMESTEP_HOURS)

    # Create a dictionary marking all chosen flights as environment delayed
    final_environment_delayed = environment_delayed_flights.copy()
    for flight_id in chosen_flight_actions:
        if flight_id not in final_environment_delayed:
            final_environment_delayed[flight_id] = True

    # manually hard code a flight number in swapped flights
    swapped_flights = 

    # Plot the final state with all chosen flights marked
    fig = state_plotter.plot_state(
        updated_flights_dict, 
        swapped_flights, 
        final_environment_delayed,  # Use the modified dictionary here
        cancelled_flights, 
        current_datetime, 
        title_appendix=f"{env_type} - All Actions Highlighted",
        show_plot=True,
        reward_and_action=(reward, env.map_index_to_action(action), total_reward)
    )
    plt.show()

    print(f"swapped_flights:")
    print(swapped_flights)
    print(f"updated_flights_dict:")
    print(updated_flights_dict)
    print(f"final_environment_delayed:")
    print(final_environment_delayed)
    print(f"cancelled_flights:")
    print(cancelled_flights)
    print(f"current_datetime:")
    print(current_datetime)
    print(f"total_reward:")
    print(total_reward)
    print(f"step_num:")
    print(step_num)
    print(f"chosen_flight_actions:")
    print(chosen_flight_actions)
    return total_reward, step_num


seed = 42

SCENARIO_FOLDER = "../../data/Testing/6ac-700-diverse/mixed_medium_Scenario_029"

models = [
    # ("../trained_models/dqn/6ac-700-diverse/1023/proactive-93.zip", "proactive", "DQN Proactive-U"), 
    # ("../trained_models/dqn/6ac-700-diverse/1024/myopic-90.zip", "myopic", "DQN Proactive-N"),
    # ("../trained_models/dqn/6ac-700-diverse/1024/reactive-95.zip", "reactive", "DQN Reactive"),
    # ("../trained_models/dqn/6ac-100-superdiverse/111/drl-greedy-246.zip", "drl-greedy", "DQN Greedy-Guided"),





    # ("../trained_models/dqn/6ac-700-diverse/1025/myopic-91.zip", "myopic", "DQN Proactive-N"),
    ("../../trained_models/dqn/6ac-700-diverse/1025/proactive-94.zip", "proactive", "DQN Proactive-U"), 
    # ("../../trained_models/dqn/6ac-700-diverse/1025/reactive-97.zip", "reactive", "DQN Reactive"),
]



for model in models:
    MODEL_PATH = model[0]
    env_type = model[1]
    plot_title = model[2]
    print("*****"*10)
    print(f"Model Path: {MODEL_PATH}")
    print(f"Environment Type: {env_type}")
    if not os.path.exists(SCENARIO_FOLDER):
        raise FileNotFoundError(f"Scenario folder not found: {SCENARIO_FOLDER}")

    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

    # Run the fixed inference loop
    run_inference_dqn(MODEL_PATH, SCENARIO_FOLDER, env_type, seed, plot_title)

SyntaxError: invalid syntax (3694043956.py, line 198)