<a href="https://colab.research.google.com/github/halduaij/S2DC-discrete/blob/main/s2dc_discrete.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
import numpy as np
import tensorflow as tf
from collections import deque
import matplotlib.pyplot as plt
from copy import deepcopy
import tensorflow_probability as tfp
# Module-level log of evaluation results. train_offline_rl appends the return
# value of test_learned_policy (not defined in this section) every 10 epochs.
reward_p=[]

########################################
# 1. Utilities and Config
########################################

def set_seeds(seed_value=38):
    """Seed Python's, NumPy's and TensorFlow's global RNGs with the same value.

    Args:
        seed_value: Seed passed to ``random.seed``, ``np.random.seed`` and
            ``tf.random.set_seed``, in that order.
    """
    seeders = (random.seed, np.random.seed, tf.random.set_seed)
    for seed_fn in seeders:
        seed_fn(seed_value)


class Config:
    """Plain container of hyperparameters for the reshuffling agent.

    Attributes are grouped by purpose below. ``get_tau`` provides the default
    softmax-temperature schedule.
    """

    def __init__(self):
        # --- Core optimisation settings ---
        self.gamma = 0.975                   # discount factor
        self.batch_size = 256                # transitions per mini-batch
        self.buffer_capacity = 10000
        self.initial_learning_rate = 0.0001
        # ceil(sqrt(500)) == 23; train_offline_rl sizes the p_D table with n**2.
        self.n = int(np.ceil(np.sqrt(500)))

        # --- Temperature schedule (see get_tau) ---
        self.initial_tau = 10.0              # temperature at step 0
        self.min_tau = 0.1                   # floor of the decayed temperature
        self.tau_decay = 0.999               # multiplicative decay per 1000 steps
        self.max_tau = 20.0                  # upper bound intended for adaptive tau

        # --- Target-network updates ---
        self.target_update_freq = 1          # soft-update every this many steps
        self.soft_update_tau = 0.005         # Polyak averaging coefficient

        # --- Reshuffling / regularisation knobs ---
        self.use_n1_in_weights = True
        self.alpha = 0.1                     # scale for N1 difference predictions
        self.lambda_cql = 0.01               # weight of the conservative Q penalty
        self.kl_div_threshold = 50.0         # KL level that triggers adaptive tau

    def get_tau(self, step: int) -> float:
        """Return the scheduled temperature for training step ``step``.

        The temperature is multiplied by ``tau_decay`` once per completed
        block of 1000 steps and never drops below ``min_tau``.
        """
        completed_blocks = step // 1000
        decayed = self.initial_tau * self.tau_decay ** completed_blocks
        return max(self.min_tau, decayed)


########################################
# 2. Network Architectures
########################################

def create_q_network(state_dim, action_dim):
    """Build a Q-network mapping a discrete state index to per-action Q-values.

    The sizes come from the arguments rather than from a global ``env``, so the
    function works wherever it is called.

    Args:
        state_dim: Number of discrete states. It is the embedding-table size,
            so inputs must be integer ids in ``[0, state_dim)``.
        action_dim: Number of discrete actions (one output unit per action).

    Returns:
        A built ``tf.keras.Sequential`` model. It takes int32 input of shape
        ``[batch, 1]`` and returns float Q-values of shape
        ``[batch, action_dim]``.
    """
    model = tf.keras.Sequential()
    # Learned dense representation of each discrete state id.
    model.add(tf.keras.layers.Embedding(input_dim=state_dim, output_dim=64, input_length=1))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(256, activation='relu'))
    model.add(tf.keras.layers.Dense(128, activation='relu'))
    model.add(tf.keras.layers.Dense(64, activation='relu'))
    model.add(tf.keras.layers.Dense(action_dim, activation='linear'))

    # Call once on a symbolic input so the weights exist before callers use
    # get_weights()/set_weights().
    sample_input = tf.keras.Input(shape=(1,), dtype=tf.int32)
    model(sample_input)
    return model

def create_prediction_network(state_dim):
    """Build a network that regresses a scalar difference (N1 or N2).

    The input is a ``(state, next_state)`` pair of discrete ids. Both ids share
    one embedding table of ``state_dim`` rows, and the two embeddings are
    flattened and passed through an MLP with batch normalisation.

    Args:
        state_dim: Number of discrete states. Ids must lie in ``[0, state_dim)``.

    Returns:
        A built ``tf.keras.Sequential`` model. It takes int32 input of shape
        ``[batch, 2]`` and returns ``[batch, 1]``. Because of the
        BatchNormalization layers, callers should pass ``training=True`` when
        fitting the model.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Embedding(input_dim=state_dim, output_dim=64, input_length=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(256, activation=None),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(128, activation=None),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(64, activation=None),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(1)  # single scalar difference
    ])

    # Build with a symbolic [batch, 2] input so weights exist immediately.
    sample_input = tf.keras.Input(shape=(2,), dtype=tf.int32)
    model(sample_input)
    return model


########################################
# 3. The Reshuffle Agent
########################################

class ReshuffleAgent:
    """Offline Q-learning agent whose Q_RS Bellman loss is reweighted per sample.

    Networks held by the agent:

    * ``Q_RS``: online Q-network trained with a per-sample weighted Bellman
      loss plus a conservative penalty on ``max_a Q_RS(s, a)``.
    * ``Q_D``: online Q-network trained with an unweighted Bellman loss.
    * ``N1``: regresses ``Q_RS_k(s, a) - BQ_RS_{k-1}(s, a)`` from ``(s, s')``.
    * ``N2``: regresses ``Q_RS_k(s, a) - Q_D_k(s, a)`` from ``(s, s')``.
    * ``target_Q_RS`` / ``target_Q_D``: soft-updated targets used by the
      double-Q Bellman operator.
    * ``Q_RS_old`` / ``Q_D_old`` / ``target_*_old``: snapshots refreshed by
      ``sync_old_networks``. They provide the previous-iteration backup.
    """

    def __init__(self, state_dim, action_dim, config: Config):
        """Create all networks, copy initial weights, and build optimizers.

        Args:
            state_dim: Forwarded to ``create_q_network`` /
                ``create_prediction_network``.
            action_dim: Forwarded to ``create_q_network``.
            config: Hyperparameters.
        """
        self.config = config
        # Number of train_step calls so far. It is advanced in plain Python by
        # train_step, outside the traced graph, so it really counts steps.
        self.step_count = 0

        # Temperature state read inside the graph by calculate_weights:
        #   scheduled_tau - config.get_tau(step_count), refreshed every step.
        #   tau_override  - used instead of scheduled_tau when > 0. It is set to
        #                   config.max_tau while the batch KL exceeds
        #                   config.kl_div_threshold, otherwise reset to 0.
        self.scheduled_tau = tf.Variable(config.get_tau(0), dtype=tf.float32, trainable=False)
        self.tau_override = tf.Variable(0.0, dtype=tf.float32, trainable=False)

        # Primary networks
        self.Q_RS = create_q_network(state_dim, action_dim)
        self.Q_D = create_q_network(state_dim, action_dim)

        # Prediction networks for differences
        self.N1 = create_prediction_network(state_dim)  # Q_RS_k - BQ_RS_(k-1)
        self.N2 = create_prediction_network(state_dim)  # Q_RS_k - Q_D_k

        # Target networks
        self.target_Q_RS = create_q_network(state_dim, action_dim)
        self.target_Q_D = create_q_network(state_dim, action_dim)

        # "Old" networks for iteration k-1 references
        self.Q_RS_old = create_q_network(state_dim, action_dim)
        self.Q_D_old = create_q_network(state_dim, action_dim)
        self.target_Q_RS_old = create_q_network(state_dim, action_dim)
        self.target_Q_D_old = create_q_network(state_dim, action_dim)

        # Start all copies from identical weights.
        self.sync_networks()
        self.sync_old_networks()

        # ExponentialDecay is a stateless function of the optimizer's step, so a
        # single schedule object can be shared. Each optimizer still tracks its
        # own iteration count.
        lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=config.initial_learning_rate,
            decay_steps=10000,
            decay_rate=0.99,
            staircase=True,
        )
        self.optimizer_rs = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
        self.optimizer_d = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
        self.optimizer_n1 = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
        self.optimizer_n2 = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    def sync_networks(self):
        """Hard-copy the online Q_RS/Q_D weights into their target networks."""
        self.target_Q_RS.set_weights(self.Q_RS.get_weights())
        self.target_Q_D.set_weights(self.Q_D.get_weights())

    def sync_old_networks(self):
        """Hard-copy the current online and target Q networks into the 'old' snapshots."""
        self.Q_RS_old.set_weights(self.Q_RS.get_weights())
        self.Q_D_old.set_weights(self.Q_D.get_weights())
        self.target_Q_RS_old.set_weights(self.target_Q_RS.get_weights())
        self.target_Q_D_old.set_weights(self.target_Q_D.get_weights())

    def soft_update_networks(self):
        """Polyak-average online weights into targets with rate ``config.soft_update_tau``."""
        tau = self.config.soft_update_tau
        for target, source in zip(self.target_Q_RS.weights, self.Q_RS.weights):
            target.assign(tau * source + (1 - tau) * target)
        for target, source in zip(self.target_Q_D.weights, self.Q_D.weights):
            target.assign(tau * source + (1 - tau) * target)

    @tf.function
    def bellman_operator(self, rewards, next_states, dones, q_network, target_network):
        """Double-Q backup: ``r + gamma * Q_target(s', argmax_a Q_online(s', a)) * (1 - done)``.

        Args:
            rewards, dones: float tensors of shape [batch].
            next_states: int tensor accepted by the Q networks.
            q_network: Network that selects the greedy next action.
            target_network: Network that evaluates that action.

        Returns:
            Float tensor of shape [batch] with the backed-up values.
        """
        # Online network SELECTS the action ...
        online_q_values = q_network(next_states)  # [batch, action_dim]
        best_actions = tf.argmax(online_q_values, axis=1)  # [batch]

        # ... target network EVALUATES it.
        target_q_values = target_network(next_states)  # [batch, action_dim]
        next_q_values = tf.reduce_sum(
            target_q_values * tf.one_hot(best_actions, target_q_values.shape[-1]),
            axis=1
        )

        return rewards + self.config.gamma * next_q_values * (1 - dones)

    @tf.function
    def calculate_weights(self, states, next_states, actions,
                          q_rs_values, q_d_values, p_d_batch, q_diff_pred):
        """Compute per-sample weights for the Q_RS Bellman loss.

        The weight of sample i is a score normalised by the batch mean::

            score_i   = exp(-(1 + N2_i) / tau) * |N1_i|
            weights_i = score_i / (mean_j score_j + 1e-8)

        so the weights average to about 1 over the batch. A reshuffled batch
        distribution ``p_rs`` proportional to ``softmax(weights) * p_D`` is also
        formed. It is only used for the returned KL diagnostic.

        Args:
            states, next_states: int tensors of shape [batch] (already squeezed).
            actions, q_rs_values, q_d_values: accepted for interface
                compatibility; not used.
            p_d_batch: p_D(s, a) of each sample, shape [batch].
            q_diff_pred: N1 prediction for each sample, shape [batch].

        Returns:
            ``(weights, kl_div, tau)``:
            - ``weights`` has shape [batch].
            - ``kl_div`` is the scalar ``sum_i p_rs_i * (log p_rs_i - log p_D_i)``.
              It is not a normalised KL, because ``p_d_batch`` is not
              renormalised over the batch.
            - ``tau`` is the temperature that was used.

        NOTE(review): the original implementation also computed a
        99th-percentile-clipped ratio ``p_rs / p_D`` but never returned it.
        The returned weights have always been the normalised score above.
        Confirm which quantity the algorithm intends.
        """
        n2_pred = self.N2(tf.stack([states, next_states], axis=1), training=False)[:, 0]

        # Override (set when KL got too large) wins over the decay schedule.
        tau = tf.where(self.tau_override > 0.0, self.tau_override, self.scheduled_tau)

        score = tf.exp(-(1.0 + n2_pred) / tau) * tf.abs(q_diff_pred)
        weights = score / (tf.reduce_mean(score) + 1e-8)

        # Reshuffled distribution p_rs ~ p_D * softmax(weights), for diagnostics.
        preference = tf.nn.softmax(weights)
        unnormalised = preference * p_d_batch
        p_rs = unnormalised / (tf.reduce_sum(unnormalised) + 1e-8)

        kl_div = tf.reduce_sum(
            p_rs * (tf.math.log(p_rs + 1e-8) - tf.math.log(p_d_batch + 1e-8))
        )
        return weights, kl_div, tau

    def train_step(self, states, actions, rewards, next_states, dones, p_d):
        """Run one optimisation step for Q_RS, Q_D, N1 and N2 on a mini-batch.

        Per-call bookkeeping is done here in Python: the step counter, the
        scheduled temperature, and whether to soft-update the targets. The
        numeric work runs in the traced ``_train_step_graph``. Inside a
        ``tf.function`` a Python ``+= 1`` only executes while tracing, which
        previously froze the counter, the tau schedule and the soft-update
        cadence.

        Args:
            states, next_states: int32 tensors of shape [batch, 1].
            actions: int32 tensor of shape [batch].
            rewards, dones: float32 tensors of shape [batch].
            p_d: Flat p_D(s, a) table indexed by ``s * num_actions + a``.

        Returns:
            Tuple of scalar tensors ``(loss_rs, loss_d, loss_n1, loss_n2,
            kl_div, bellman_error, value_error, mean_weight)``.
        """
        self.step_count += 1
        self.scheduled_tau.assign(self.config.get_tau(self.step_count))
        # A Python bool: at most two graph traces (True / False) are cached.
        do_soft_update = (self.step_count % self.config.target_update_freq) == 0
        return self._train_step_graph(states, actions, rewards, next_states,
                                      dones, p_d, do_soft_update)

    @tf.function
    def _train_step_graph(self, states, actions, rewards, next_states, dones,
                          p_d, do_soft_update):
        """Traced body of ``train_step``; ``do_soft_update`` is a Python bool."""
        # 1. Current Q-values (before any update this step).
        q_rs = self.Q_RS(states)  # [batch, action_dim]
        q_d = self.Q_D(states)    # [batch, action_dim]
        action_mask = tf.one_hot(actions, q_rs.shape[-1])
        q_rs_values = tf.reduce_sum(q_rs * action_mask, axis=1)
        q_d_values = tf.reduce_sum(q_d * action_mask, axis=1)

        # 2. Backup from the iteration k-1 snapshots (target for N1).
        bq_rs_old = self.bellman_operator(rewards, next_states, dones,
                                          self.Q_RS_old, self.target_Q_RS_old)

        # 3. N1 prediction on (s, s') in inference mode, used for weighting.
        states_flat = tf.squeeze(states, axis=1)
        next_states_flat = tf.squeeze(next_states, axis=1)
        s_ns_stack = tf.stack([states_flat, next_states_flat], axis=1)
        n1_pred = self.N1(s_ns_stack, training=False)[:, 0]

        # 4. p_D(s, a) for the batch from the flat table.
        batch_indices = states_flat * q_rs.shape[-1] + actions
        p_d_batch = tf.gather(p_d, batch_indices)

        # 5. Per-sample weights for the Q_RS loss.
        weights, kl_div, _ = self.calculate_weights(
            states_flat, next_states_flat, actions,
            q_rs_values, q_d_values, p_d_batch, n1_pred
        )

        # Adaptive tau: flatten the weighting (max_tau) while the KL is above
        # the threshold; otherwise clear the override so the schedule applies.
        kl_too_large = kl_div > self.config.kl_div_threshold
        self.tau_override.assign(tf.where(
            kl_too_large,
            tf.constant(self.config.max_tau, dtype=tf.float32),
            tf.constant(0.0, dtype=tf.float32),
        ))

        # 6. Backup for the CURRENT Q_RS (a constant w.r.t. the tape below).
        bq_rs_current = self.bellman_operator(rewards, next_states, dones,
                                              self.Q_RS, self.target_Q_RS)

        # 7. Update Q_RS: weighted MSE + conservative penalty. Q is recomputed
        #    under the tape so the penalty actually produces gradients.
        with tf.GradientTape() as tape:
            q_rs_live = self.Q_RS(states)
            q_rs_actions_current = tf.reduce_sum(q_rs_live * action_mask, axis=1)
            bellman_loss = tf.reduce_mean(weights * tf.square(bq_rs_current - q_rs_actions_current))
            conservative_penalty = self.config.lambda_cql * tf.reduce_mean(
                tf.reduce_max(q_rs_live, axis=1))
            loss_rs = bellman_loss + conservative_penalty
        grads_rs = tape.gradient(loss_rs, self.Q_RS.trainable_variables)
        self.optimizer_rs.apply_gradients(zip(grads_rs, self.Q_RS.trainable_variables))

        # 8. Update Q_D with a plain (unweighted) MSE.
        bq_d_current = self.bellman_operator(rewards, next_states, dones,
                                             self.Q_D, self.target_Q_D)
        with tf.GradientTape() as tape:
            q_d_actions_current = tf.reduce_sum(self.Q_D(states) * action_mask, axis=1)
            loss_d = tf.reduce_mean(tf.square(bq_d_current - q_d_actions_current))
        grads_d = tape.gradient(loss_d, self.Q_D.trainable_variables)
        self.optimizer_d.apply_gradients(zip(grads_d, self.Q_D.trainable_variables))

        # 9. Update N1 towards Q_RS_k(s, a) - BQ_RS_{k-1}(s, a), using the
        #    post-update Q_RS. training=True so BatchNorm uses batch statistics
        #    and updates its moving averages.
        current_q_rs_actions = tf.reduce_sum(self.Q_RS(states) * action_mask, axis=1)
        n1_target = current_q_rs_actions - bq_rs_old
        with tf.GradientTape() as tape:
            n1_values = self.N1(s_ns_stack, training=True)[:, 0]
            loss_n1 = tf.reduce_mean(tf.square(n1_target - n1_values))
        grads_n1 = tape.gradient(loss_n1, self.N1.trainable_variables)
        self.optimizer_n1.apply_gradients(zip(grads_n1, self.N1.trainable_variables))

        # 10. Update N2 towards Q_RS_k(s, a) - Q_D_k(s, a).
        current_q_d_actions = tf.reduce_sum(self.Q_D(states) * action_mask, axis=1)
        n2_target = current_q_rs_actions - current_q_d_actions
        with tf.GradientTape() as tape:
            n2_values = self.N2(s_ns_stack, training=True)[:, 0]
            loss_n2 = tf.reduce_mean(tf.square(n2_target - n2_values))
        grads_n2 = tape.gradient(loss_n2, self.N2.trainable_variables)
        self.optimizer_n2.apply_gradients(zip(grads_n2, self.N2.trainable_variables))

        # 11. Soft-update targets (decided per call in train_step).
        if do_soft_update:
            self.soft_update_networks()

        # Diagnostics: Bellman error of Q_RS, Q_RS vs Q_D gap, mean weight.
        bellman_error_batch = tf.reduce_mean(tf.square(bq_rs_current - q_rs_actions_current))
        value_error_batch = tf.reduce_mean(tf.square(q_rs_values - q_d_values))
        weight_mean_batch = tf.reduce_mean(weights)

        return (loss_rs, loss_d, loss_n1, loss_n2, kl_div,
                bellman_error_batch, value_error_batch, weight_mean_batch)

    def estimate_p_d(self, states, actions, num_states, num_actions, alpha=0.01):
        """Estimate the dataset's joint state-action distribution p_D(s, a).

        The estimate is built in three steps:
        1. Count visits of every (s, a) pair.
        2. Add a smoothing pseudo-count of ``alpha * (1 + (1 - freq(s)))`` to
           each action of state ``s``. Here ``freq(s)`` is ``s``'s share of all
           visits, so rarely visited states get up to twice as much smoothing.
        3. Normalise the result.

        ``self`` is not used.

        Args:
            states: 1-D array-like of integer state ids in ``[0, num_states)``.
            actions: Array-like of integer action ids in ``[0, num_actions)``,
                same length as ``states``.
            num_states: Number of states in the table.
            num_actions: Number of actions in the table.
            alpha: Base smoothing pseudo-count (default 0.01).

        Returns:
            ``np.ndarray`` of shape ``(num_states * num_actions,)`` summing to 1,
            indexed by ``s * num_actions + a``.

        Raises:
            ValueError: if the lengths differ, ``alpha`` is negative, an id is
                out of range, or there is nothing to normalise (no data and
                ``alpha == 0``).
        """
        if len(states) != len(actions):
            raise ValueError("States and actions must have same length")
        if alpha < 0:
            raise ValueError("Alpha must be non-negative")

        states = np.asarray(states).astype(np.int64).reshape(-1)
        actions = np.asarray(actions).astype(np.int64).reshape(-1)
        if states.size and (states.min() < 0 or states.max() >= num_states):
            raise ValueError("State indices must lie in [0, num_states)")
        if actions.size and (actions.min() < 0 or actions.max() >= num_actions):
            raise ValueError("Action indices must lie in [0, num_actions)")

        # Visit counts laid out as a [num_states, num_actions] table.
        counts = np.bincount(states * num_actions + actions,
                             minlength=num_states * num_actions).astype(np.float64)
        counts = counts.reshape(num_states, num_actions)

        # Per-state smoothing, between alpha and 2 * alpha.
        state_visits = counts.sum(axis=1)
        state_freq = state_visits / (np.sum(state_visits) + 1e-8)
        counts += (alpha * (1 + (1 - state_freq)))[:, np.newaxis]

        counts = counts.reshape(-1)
        total = np.sum(counts)
        if total <= 0:
            raise ValueError("Cannot estimate p_d: no transitions and alpha == 0")
        p_d = counts / total

        # Internal invariants of the construction above.
        assert np.all(p_d >= 0), "All probabilities must be non-negative"
        assert np.abs(np.sum(p_d) - 1.0) < 1e-6, "Probabilities must sum to 1"

        return p_d

########################################
# 4. Training Loop
########################################

def _plot_step_metrics(bellman_errors_by_step, value_errors_by_step, weights_by_step):
    """Show one figure with the three per-mini-batch training diagnostics."""
    plt.figure(figsize=(16, 4))

    # 1) Bellman Error by step
    plt.subplot(1, 3, 1)
    plt.plot(bellman_errors_by_step, label="Bellman Error per Step")
    plt.xlabel("Step")
    plt.ylabel("MSE")
    plt.title("Bellman Error (per mini-batch)")
    plt.legend()

    # 2) Value Error by step
    plt.subplot(1, 3, 2)
    plt.plot(value_errors_by_step, color='green', label="Value Error per Step")
    plt.xlabel("Step")
    plt.ylabel("MSE (Q_RS - Q_D)")
    plt.title("Value Error (per mini-batch)")
    plt.legend()

    # 3) Weights by step
    plt.subplot(1, 3, 3)
    plt.plot(weights_by_step, color='red', label="Mean Weights per Step")
    plt.xlabel("Step")
    plt.ylabel("Mean( p_RS / p_D )")
    plt.title("Reshuffling Weights (per mini-batch)")
    plt.legend()

    plt.tight_layout()
    plt.show()


def train_offline_rl(agent: ReshuffleAgent, dataset, num_epochs=250):
    """Train ``agent`` on a static offline dataset and plot per-step diagnostics.

    Args:
        agent: The agent to train (its ``train_step`` is called per mini-batch).
        dataset: Sequence of ``(state, action, reward, next_state, done)``
            tuples. The caller's sequence is not modified; a shuffled shallow
            copy is used.
        num_epochs: Number of passes over the dataset.

    Returns:
        ``(bellman_errors_by_step, value_errors_by_step, weights_by_step)``:
        lists with one value per mini-batch across all epochs.

    Side effects:
        - Every 10th epoch (starting at epoch 0), appends the result of
          ``test_learned_policy`` (defined outside this section) to the global
          ``reward_p``.
        - Prints one summary line per epoch.
        - Shows a matplotlib figure.
    """
    # Only the list order is shuffled and the tuples are never mutated, so a
    # shallow copy is enough (no need for deepcopy).
    dataset = list(dataset)

    # 1. Estimate p_D(s, a) once from the entire dataset.
    all_states = np.array([s for (s, _, _, _, _) in dataset]).flatten()
    all_actions = np.array([a for (_, a, _, _, _) in dataset])
    p_d = agent.estimate_p_d(
        all_states, all_actions,
        agent.config.n**2,            # number of states in the p_D table
        agent.Q_RS.output_shape[-1],  # number of actions
    )
    p_d = tf.constant(p_d, dtype=tf.float32)

    # Step-level metrics, one entry per mini-batch.
    bellman_errors_by_step = []
    value_errors_by_step = []
    weights_by_step = []
    batch_size = agent.config.batch_size

    for epoch in range(num_epochs):
        # 2. Refresh the k-1 snapshots at the start of every epoch.
        agent.sync_old_networks()
        if epoch % 10 == 0:
            reward_p.append(test_learned_policy(num_test_episodes=4, model=agent.Q_RS, print_e=False))

        # Running sums as Python floats, so the summary line prints correctly
        # even when no mini-batch was processed.
        total_loss_rs = 0.0
        total_loss_d = 0.0
        total_loss_n1 = 0.0
        total_loss_n2 = 0.0
        total_kl_div = 0.0
        epoch_rewards = 0.0
        count_batches = 0

        random.shuffle(dataset)

        # 3. Process data in mini-batches.
        for i in range(0, len(dataset), batch_size):
            batch = dataset[i:i + batch_size]
            states, actions, rewards, next_states, dones = zip(*batch)

            states_tf = tf.convert_to_tensor(np.array(states).reshape(-1, 1), dtype=tf.int32)
            next_states_tf = tf.convert_to_tensor(np.array(next_states).reshape(-1, 1), dtype=tf.int32)
            actions_tf = tf.convert_to_tensor(actions, dtype=tf.int32)
            rewards_tf = tf.convert_to_tensor(rewards, dtype=tf.float32)
            dones_tf = tf.convert_to_tensor(dones, dtype=tf.float32)

            (loss_rs, loss_d, loss_n1, loss_n2, kl_div,
             bellman_err_b, value_err_b, weight_mean_b) = agent.train_step(
                states_tf, actions_tf, rewards_tf,
                next_states_tf, dones_tf, p_d
            )

            total_loss_rs += float(loss_rs)
            total_loss_d += float(loss_d)
            total_loss_n1 += float(loss_n1)
            total_loss_n2 += float(loss_n2)
            total_kl_div += float(kl_div)
            epoch_rewards += float(tf.reduce_sum(rewards_tf))
            count_batches += 1

            bellman_errors_by_step.append(bellman_err_b.numpy())
            value_errors_by_step.append(value_err_b.numpy())
            weights_by_step.append(weight_mean_b.numpy())

        # With zero batches every sum is 0, so max(..., 1) yields 0 averages.
        n_batches = max(count_batches, 1)
        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Loss Q_RS: {total_loss_rs / n_batches:.6f}, "
              f"Loss Q_D: {total_loss_d / n_batches:.6f}, "
              f"Loss N1: {total_loss_n1 / n_batches:.6f}, "
              f"Loss N2: {total_loss_n2 / n_batches:.6f}, "
              f"KL: {total_kl_div / n_batches:.6f}, "
              f"Total R: {epoch_rewards:.2f}")

    _plot_step_metrics(bellman_errors_by_step, value_errors_by_step, weights_by_step)

    return bellman_errors_by_step, value_errors_by_step, weights_by_step


########################################
# 5. Evaluation (Optional)
########################################

def _unwrap_reset(reset_result):
    """Return the observation from ``env.reset()``.

    Accepts both the classic bare-observation form and the Gym>=0.26 /
    Gymnasium ``(obs, info)`` form.
    """
    if (isinstance(reset_result, tuple) and len(reset_result) == 2
            and isinstance(reset_result[1], dict)):
        return reset_result[0]
    return reset_result


def _unwrap_step(step_result):
    """Return ``(next_state, reward, done)`` from ``env.step()``.

    Accepts both the classic 4-tuple and the 5-tuple
    ``(obs, reward, terminated, truncated, info)`` API.
    """
    if len(step_result) == 5:
        next_state, reward, terminated, truncated, _ = step_result
        return next_state, reward, bool(terminated or truncated)
    next_state, reward, done, _ = step_result
    return next_state, reward, done


def evaluate_agent(env, agent: ReshuffleAgent, num_episodes=50):
    """Run greedy episodes with respect to ``agent.Q_RS`` and average the return.

    Works with both the classic Gym API and the Gym>=0.26 / Gymnasium API
    (``reset`` returning ``(obs, info)``, ``step`` returning a 5-tuple).

    Args:
        env: Environment with integer observations.
        agent: Agent whose ``Q_RS`` network selects actions.
        num_episodes: Number of episodes to average over.

    Returns:
        Average undiscounted episode reward.
    """
    total_rewards = 0
    for _ in range(num_episodes):
        state = _unwrap_reset(env.reset())
        done = False
        episode_reward = 0
        while not done:
            state_tensor = tf.convert_to_tensor([[state]], dtype=tf.int32)
            q_values = agent.Q_RS(state_tensor)  # shape [1, action_dim]
            action = tf.argmax(q_values[0]).numpy()
            next_state, reward, done = _unwrap_step(env.step(action))
            episode_reward += reward
            state = next_state
        total_rewards += episode_reward
    return total_rewards / num_episodes


########################################
# 6. Usage Example
########################################
if __name__ == "__main__":
    # Seeding first is equivalent: building Config consumes no randomness.
    set_seeds()
    config = Config()

    # This cell expects the notebook to have defined `dataset` (a list of
    # (state, action, reward, next_state, done) tuples) and `env` beforehand.
    if 'dataset' not in locals():
        raise ValueError("Dataset not defined. Please provide training dataset.")
    if 'env' not in locals():
        raise ValueError("Environment not defined. Please provide environment.")

    agent = ReshuffleAgent(
        state_dim=env.observation_space.n,
        action_dim=env.action_space.n,
        config=config,
    )

    # compile() here only associates an optimizer/loss with each model; the
    # custom train_step applies the gradients itself.
    agent.Q_RS.compile(optimizer=agent.optimizer_rs, loss='mse')
    agent.Q_D.compile(optimizer=agent.optimizer_d, loss='mse')
    agent.N1.compile(optimizer=agent.optimizer_n1, loss='mse')
    agent.N2.compile(optimizer=agent.optimizer_n2, loss='mse')

    with tf.device('/GPU:0'):
        bellman_steps, value_steps, weight_steps = train_offline_rl(
            agent, dataset, num_epochs=200
        )

    # Persist the reweighted Q-network.
    model_save_dir = "saved_models"
    os.makedirs(model_save_dir, exist_ok=True)
    agent.Q_RS.save(os.path.join(model_save_dir, "taxi_model_maxsteps_66.keras"))

    # Optional greedy evaluation:
    # avg_reward = evaluate_agent(env, agent)
    # print("Average Evaluation Reward:", avg_reward)