# In [None]:
import gymnasium as gym
import matplotlib.pyplot as plt
from KeyRef_2_env import KeyRef2_Env
from stable_baselines3.common.vec_env     import DummyVecEnv
from stable_baselines3.common.callbacks   import BaseCallback, EvalCallback
from stable_baselines3.common.monitor     import Monitor
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import DQN
import os
import datetime
import pandas as pd
from stable_baselines3.common.callbacks import BaseCallback

# Environment instance handed to the DQN agent below
env = KeyRef2_Env()
# Action names, positioned so that the list index equals the discrete action id
action_list = [f"CDR{rule_no}" for rule_no in range(1, 7)]
                          
# Create directories for models and logs.
# The timestamp makes the model/TensorBoard directories and the log file
# names unique per run; the two training-log directories are shared by runs.
current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
models_dir = f"models/keyref2-{current_time}"
logdir = f"logs/keyref2-{current_time}"
log_training_txt_dir = "keyref2_log_training_txt"
log_training_excel_dir = "keyref2_log_training_excel"

# exist_ok=True makes creation idempotent and removes the window between
# "check that the path is missing" and "create it" that the explicit
# os.path.exists() test left open.
for _run_dir in (models_dir, logdir, log_training_txt_dir, log_training_excel_dir):
    os.makedirs(_run_dir, exist_ok=True)

# Generate unique file names based on current time
log_file          = os.path.join(log_training_txt_dir,   f"training_{current_time}.txt")
excel_file        = os.path.join(log_training_excel_dir, f"training_{current_time}.xlsx")
action_count_file = os.path.join(log_training_txt_dir,   f"action_count_{current_time}.txt")
action_excel_file = os.path.join(log_training_excel_dir, f"action_count_{current_time}.xlsx")

# Define the custom callback -------------------------------------------------------------
class CustomCallback(BaseCallback):
    """Record per-episode reward/tardiness and per-action usage during training.

    Rewards of the first (index 0) environment are accumulated step by step
    and an episode is closed when that environment reports ``done``.
    Stable-Baselines3 calls ``_on_rollout_end`` after every rollout
    collection, which for off-policy algorithms such as DQN is controlled by
    ``train_freq`` (a handful of steps by default) and not by episode
    boundaries, so that hook is deliberately not used for episode logging.

    When training ends, the episode table and the action counts are each
    written to an Excel workbook and to a plain-text file.

    Args:
        log_dir: Stored as ``self.log_dir``; not otherwise used by this class.
        excel_file: Path of the Excel file that receives the episode table.
        txt_file: Path of the text file that receives the episode table.
        action_count_file: Path of the text file that receives action counts.
        action_excel_file: Path of the Excel file that receives action counts.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, log_dir, excel_file, txt_file, action_count_file, action_excel_file, verbose=0):
        super().__init__(verbose)
        self.log_dir = log_dir
        self.excel_file = excel_file
        self.txt_file = txt_file
        self.action_count_file = action_count_file
        self.action_excel_file = action_excel_file
        self.logs = []             # one dict per finished episode
        self.episode_rewards = []  # running reward total, one entry per episode
        self.action_counts = {}    # action name -> number of times it was chosen
        self.episode_start = True  # True when the next step opens a new episode

    def _on_training_start(self) -> None:
        """Reset the per-action counters to zero for every known action name."""
        self.action_counts = {action: 0 for action in action_list}

    def _on_step(self) -> bool:
        """Accumulate reward, count the chosen action and close finished episodes.

        Returns:
            Always ``True`` so that training is never interrupted by this callback.
        """
        if self.episode_start:
            self.episode_rewards.append(0.0)
            self.episode_start = False

        # Record reward for the current step (as a Python float so the running
        # total is not accumulated in the reward array's narrower dtype).
        self.episode_rewards[-1] += float(self.locals['rewards'][0])

        # Increment action count
        actions = self.locals.get('actions', None)
        if actions is not None:
            action_name = action_list[actions[0]]
            self.action_counts[action_name] += 1

        # An episode ends when the environment reports done, regardless of
        # where rollout boundaries happen to fall.
        dones = self.locals.get('dones', None)
        if dones is not None and dones[0]:
            self._record_episode()

        return True

    def _record_episode(self) -> None:
        """Log the episode that has just finished and arm the next one."""
        sum_reward = self.episode_rewards[-1]
        # NOTE(review): the vectorized env wrapper auto-resets a finished
        # episode before callbacks run, so this attribute is read after
        # reset() -- confirm the environment keeps `all_Tard` across reset,
        # or expose the final value through the step `info` dict instead.
        tardiness = self.training_env.get_attr('unwrapped')[0].all_Tard

        self.logger.record('train/episode_reward',   sum_reward)
        self.logger.record('train/actual_tardiness', tardiness)

        self.logs.append({
            'episode': len(self.episode_rewards),
            'sum_reward': sum_reward,
            'tardiness': tardiness
        })

        self.episode_start = True

    def _on_training_end(self) -> None:
        """Write the episode table and the action counts to Excel and text files."""
        # Save logs to Excel
        df = pd.DataFrame(self.logs)
        df.to_excel(self.excel_file, index=False)

        action_df = pd.DataFrame(list(self.action_counts.items()), columns=['Action', 'Count'])
        action_df.to_excel(self.action_excel_file, index=False)

        # Save logs to text file
        with open(self.txt_file, 'w', encoding='utf-8') as f:
            f.write(df.to_string(index=False))
        with open(self.action_count_file, 'w', encoding='utf-8') as f:
            f.write(action_df.to_string(index=False))

# Instantiate the logging callback with this run's output locations
callback_kwargs = {
    "log_dir": logdir,
    "excel_file": excel_file,
    "txt_file": log_file,
    "action_count_file": action_count_file,
    "action_excel_file": action_excel_file,
}
callback = CustomCallback(verbose=1, **callback_kwargs)

# Q-network layout: 5 hidden layers with 128 neurons each
policy_kwargs = {"net_arch": [128] * 5}

# Archive path for the DQN model inside this run's models directory
model_path = os.path.join(models_dir, "DQN_.zip")

# Initialize the DQN model
model = DQN(
    'MlpPolicy',
    env,
    verbose=1,
    policy_kwargs=policy_kwargs,
    # Replay memory and minibatch sizes
    buffer_size=1000,
    batch_size=32,
    # Discount factor and soft target-update coefficient
    gamma=0.9,
    tau=0.01,
    # Exploration rate: initial and final epsilon
    exploration_initial_eps=0.5,
    exploration_final_eps=0.1,
)

