In [1]:
import os
import numpy as np
from stable_baselines3 import DQN
from src.environment import AircraftDisruptionEnv
from scripts.visualizations import StatePlotter
from scripts.utils import load_scenario_data, get_training_metadata, NumpyEncoder
from src.config import *
import re
import torch
import time
import ipywidgets as widgets
from IPython.display import display, clear_output, Image as IPImage
from io import BytesIO
import matplotlib.pyplot as plt

from scripts.logger import create_new_id, log_inference_metadata, log_inference_scenario_data, find_corresponding_training_id, get_config_variables, convert_to_serializable, update_id_status
import src.config as config
from datetime import datetime
import json

blabla_reward = {}

def run_inference_dqn_single(model_path, scenario_folder, env_type, seed):
    """
    Runs inference on a single scenario and logs detailed results.

    Args:
        model_path (str): Path to the trained model.
        scenario_folder (str): Path to the scenario folder.
        env_type (str): Type of environment ("myopic" or "proactive").
        seed (int): Seed for reproducibility.
    """
    # Load scenario data
    data_dict = load_scenario_data(scenario_folder)
    aircraft_dict = data_dict['aircraft']
    flights_dict = data_dict['flights']
    rotations_dict = data_dict['rotations']
    alt_aircraft_dict = data_dict['alt_aircraft']
    config_dict = data_dict['config']

    # Initialize environment
    env = AircraftDisruptionEnv(
        aircraft_dict, flights_dict, rotations_dict, alt_aircraft_dict, config_dict, env_type=env_type
    )

    # Load trained model and configure
    model = DQN.load(model_path)
    model.set_env(env)
    model.policy.set_training_mode(False)
    model.exploration_rate = 0.0

    # Set seed
    np.random.seed(seed)
    torch.manual_seed(seed)

    print(f"Seed: {seed}")

    # Initialize visualization
    state_plotter = StatePlotter(
        aircraft_dict=env.aircraft_dict,
        flights_dict=env.flights_dict,
        rotations_dict=env.rotations_dict,
        alt_aircraft_dict=env.alt_aircraft_dict,
        start_datetime=env.start_datetime,
        end_datetime=env.end_datetime,
        uncertain_breakdowns=env.uncertain_breakdowns,
    )

    obs, _ = env.reset()
    done_flag = False
    total_reward = 0
    step_num = 0
    max_steps = 1000

    # Scenario log
    scenario_log = {
        "scenario_folder": scenario_folder,
        "env_type": env_type,
        "seed": seed,
        "steps": [],
        "total_reward": 0,
        "runtime_start": datetime.utcnow().isoformat() + "Z",
    }

    while not done_flag and step_num < max_steps:
        # Step-level pre-action info
        step_info_before_action = {
            "num_cancelled_flights": len(env.cancelled_flights),
            "num_delayed_flights": len(env.environment_delayed_flights),
            "num_resolved_conflicts": len(env.resolved_conflicts),
            "current_datetime": env.current_datetime.isoformat(),
        }

        # Action selection
        action_mask = obs['action_mask']
        obs = {key: np.array(value, dtype=np.float32) for key, value in obs.items()}
        obs_tensor = model.policy.obs_to_tensor(obs)[0]
        q_values = model.policy.q_net(obs_tensor).detach().cpu().numpy().squeeze()
        masked_q_values = q_values.copy()
        masked_q_values[action_mask == 0] = -np.inf
        action = np.random.choice(np.where(action_mask == 1)[0])

        if action_mask[action] == 0:
            raise ValueError(f"Invalid action selected: {action}")

        # Environment step
        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += reward
        done_flag = terminated or truncated
        action_mapped = env.map_index_to_action(action)

        # Log step data
        step_log = {
            "step_num": step_num,
            "action": action,
            "flight_action": action_mapped[0],
            "aircraft_action": action_mapped[1],
            "reward": reward,
            "total_reward": total_reward,
            "q_values": q_values.tolist(),
            "masked_q_values": masked_q_values.tolist(),
            "action_mask": action_mask.tolist(),
            "done_flag": done_flag,
            "info_after_step": convert_to_serializable(env.info_after_step),
            "step_info_before_action": step_info_before_action,
        }

        scenario_log["steps"].append(step_log)
        step_num += 1

    blabla_reward[scenario_folder] = total_reward
    # print(total_reward)

    scenario_log["total_reward"] = total_reward
    scenario_log["runtime_end"] = datetime.utcnow().isoformat() + "Z"

    # Log scenario data
    # log_inference_scenario_data(inference_id, scenario_log)

    print(f"Total Reward: {total_reward}")
    print(f"Steps Taken: {step_num}")
    return blabla_reward


def run_inference_dqn_folder(model_path, scenario_folder, env_type, seed):
    """
    Runs inference on all scenarios in a folder.

    Args:
        model_path (str): Path to the trained model.
        scenario_folder (str): Path to the folder containing scenarios.
        env_type (str): Type of environment ("myopic" or "proactive").
        seed (int): Seed for reproducibility.
    """
    # inference_id = create_new_id("inference")
    # inference_id = "20241217-1445"
    runtime_start = time.time()

    # Load training metadata
    # training_id = find_corresponding_training_id(model_path, env_type)
    # training_logs_path = f"../logs/training/training_{training_id}.json"
    # training_metadata = get_training_metadata(training_logs_path, env_type)

    # Metadata for inference
    # inference_metadata = {
    #     "inference_id": inference_id,
    #     "model_path": model_path,
    #     "scenario_folder": scenario_folder,
    #     "env_type": env_type,
    #     "seed": seed,
    #     "training_metadata": training_metadata,
    #     "runtime_start": datetime.utcnow().isoformat() + "Z",
    # }

    # log_inference_metadata(inference_id, inference_metadata)

    # complete_inference_log = {
    #     "inference_id": inference_id,
    #     "runtime_start": inference_metadata["runtime_start"],
    #     "runtime_end": None,
    #     "scenarios": {},
    # }

    for scenario in os.listdir(scenario_folder):
        scenario_path = os.path.join(scenario_folder, scenario)
        if os.path.isdir(scenario_path):
            blabla_reward = run_inference_dqn_single(model_path, scenario_path, env_type, seed)
            # complete_inference_log["scenarios"][scenario] = scenario_log

    # complete_inference_log["runtime_end"] = datetime.utcnow().isoformat() + "Z"
    runtime_end = time.time()
    print(f"Runtime: {runtime_end - runtime_start}")
    # log_file_path = os.path.join("../logs", "inference", f"inference_{inference_id}.json")

    # # Save complete inference log
    # with open(log_file_path, 'w') as log_file:
    #     json.dump(complete_inference_log, log_file, indent=4, cls=NumpyEncoder)

    # # Mark as done in ids.json
    # update_id_status(inference_id, "finished")

    # print(f"Inference log saved to {log_file_path}")
    return blabla_reward


# Main logic to run inference
latest = True
env_type = "myopic"

# if latest:
#     MODEL_PATH = f"../trained_models/dqn/{env_type}_3ac-{max(int(model.split('-')[1].split('.')[0]) for model in os.listdir('../trained_models/dqn') if model.startswith(f'{env_type}_3ac-'))}.zip"
# else:
#     MODEL_PATH = f"../trained_models/dqn/_perfect_{env_type}_3ac-2.zip"

# print(f"Model Path: {MODEL_PATH}")



seed = int(time.time())
SCENARIO_FOLDER = "../data/Locked/alpha/"

# PROACTIVE EXAMPLE
SCENARIO_FOLDER ="../data/Training/6ac-700-diverse/"

# NOT WORKING
# MODEL_PATH = "../trained_models/dqn/myopic_3ac-16.zip"

# ACTUALLY WORKING
MODEL_PATH = "../trained_models/dqn/6ac-700-diverse/23/myopic-training_13.zip"


print(f"Environment Type: {env_type}")

blabla_reward = run_inference_dqn_folder(MODEL_PATH, SCENARIO_FOLDER, env_type, seed)

# Calculate and print average reward
avg_reward = sum(blabla_reward.values()) / len(blabla_reward) if blabla_reward else 0
print(f"\nAverage Reward across all scenarios: {avg_reward:.2f}")


Environment Type: myopic
Seed: 1734444388

Calculating reward for action: flight 4, aircraft 6
  -0.0 penalty for 0 minutes of additional delay (capped at 2500)
  -0 penalty for 0 new cancelled flights: set()
  -0 penalty for inaction with remaining conflicts
  +86.80000000000001 bonus for proactive action (868.0 minutes ahead)
  -60.0 penalty for time progression
--------------------------------
Total reward: 26.8
--------------------------------
    prob: 0.37 is not nan, 0.0, or 1.0, for aircraft B737#5, so termination = False
checked and terminated: False

Calculating reward for action: flight 20, aircraft 6
  -2500 penalty for 1528.0 minutes of additional delay (capped at 2500)
  -0 penalty for 0 new cancelled flights: set()
  -0 penalty for inaction with remaining conflicts
  +10.5 bonus for proactive action (105.0 minutes ahead)
  -120.0 penalty for time progression
--------------------------------
Total reward: -2609.5
--------------------------------
    prob: 0.37 is not nan,

KeyboardInterrupt: 