<a href="https://colab.research.google.com/github/halduaij/S2DC-discrete/blob/main/s2dc_continuous.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Notebook setup cell: install packages, start a virtual display, set the
# rendering environment variables, and only then import gymnasium.
# The lines starting with "!" are IPython shell escapes, so this cell runs
# only inside a notebook front end; it is not valid plain Python.
!apt-get update -qq
!apt-get install -y python3-opengl > /dev/null
!apt-get install -y xvfb > /dev/null
!pip install gymnasium[mujoco] --quiet
!pip install pyvirtualdisplay --quiet
!pip install pyglet --quiet  # Additional dependency for rendering

# Standard / plotting / notebook-display imports used by later cells
import os
import sys
import time
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
from IPython.display import Image, clear_output

# Start an invisible 1400x900 virtual display before anything tries to render
print("Setting up virtual display...")
from pyvirtualdisplay import Display
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()

# Rendering backend selection via environment variables.
# NOTE(review): presumably both must be set before MuJoCo/PyOpenGL are
# imported to take effect — confirm against the installed versions.
os.environ['MUJOCO_GL'] = 'egl'
os.environ['PYOPENGL_PLATFORM'] = 'egl'

# gymnasium is imported last, after the display and environment variables
# above are in place
print("Importing gymnasium...")
import gymnasium as gym


In [None]:
import json
from datetime import datetime
import csv
from pathlib import Path


class CustomLogger:
    """Minimal CSV/JSON experiment logger.

    Each instance creates ``<log_dir>/<YYYYmmdd_HHMMSS>/`` and, inside it,
    ``training_metrics.csv`` and ``eval_metrics.csv`` with their header rows
    already written. ``save_config`` additionally writes ``config.json``.

    Rows are appended by re-opening the file on every call, so nothing is
    held open between calls.
    """

    def __init__(self, log_dir="logs"):
        """Create the timestamped run directory and the two CSV files.

        Args:
            log_dir: parent directory under which the run folder is created.
        """
        # Timestamp (second resolution) names the per-run folder.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_dir = Path(log_dir) / timestamp
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Column layout of each CSV file
        self.training_fields = [
            'step', 'loss_q_rs', 'loss_q_d', 'loss_n1', 'loss_n2',
            'kl_div', 'actor_loss', 'alpha_loss', 'alpha_val',
            'bellman_err','bellman_err2', 'value_err', 'weight_mean', 'tau_used'
        ]

        self.eval_fields = ['step', 'return_rs', 'return_d']

        # Paths of the CSV files (headers are written immediately)
        self.files = {
            'training': self._init_csv('training_metrics.csv', self.training_fields),
            'eval': self._init_csv('eval_metrics.csv', self.eval_fields)
        }

        # Field names keyed the same way as ``self.files``
        self.fieldnames = {
            'training': self.training_fields,
            'eval': self.eval_fields
        }

    def _init_csv(self, filename, fieldnames):
        """Create (or truncate) ``filename`` in the run folder, write the header, return its path."""
        file = self.log_dir / filename
        with open(file, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
        return file

    def _append_row(self, kind, metrics, step):
        """Append one row to the CSV identified by ``kind`` ('training' or 'eval').

        The row is built from a shallow copy of ``metrics`` so the caller's
        dict is left untouched; ``step`` overrides any 'step' key it holds.
        Keys that are not declared columns make ``csv.DictWriter`` raise
        ``ValueError``; missing columns are written as empty cells.
        """
        row = dict(metrics)
        row['step'] = step
        with open(self.files[kind], 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=self.fieldnames[kind])
            writer.writerow(row)

    def log(self, metrics, step):
        """Append a training-metrics row for ``step`` without modifying ``metrics``."""
        self._append_row('training', metrics, step)

    def log_eval(self, eval_metrics, step):
        """Append an evaluation-metrics row for ``step`` without modifying ``eval_metrics``."""
        self._append_row('eval', eval_metrics, step)

    def save_config(self, config_dict):
        """Remember ``config_dict`` on the instance and dump it to ``config.json``."""
        self.config = config_dict
        with open(self.log_dir / 'config.json', 'w') as f:
            json.dump(config_dict, f, indent=4)

    def finish(self):
        """Print where the logs of this run were written."""
        print(f"Logs saved to: {self.log_dir}")


In [None]:
import os
import random
import numpy as np
import pickle
import time
import wandb
import gym
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.mixture import GaussianMixture
from sklearn.decomposition import PCA
import tensorflow_probability as tfp

# ----------------------------------------------------
# 1. Config and Utilities
# ----------------------------------------------------
class Config:
    """Hyperparameters for the reshuffle-SAC experiments.

    Every setting is a plain instance attribute assigned in ``__init__``.
    ``get_tau`` derives the reshuffling temperature for a training step.
    """

    def __init__(self):
        # --- Core RL settings -------------------------------------------
        self.gamma = 0.99
        self.batch_size = 256
        self.buffer_capacity = 1_000_000
        self.init_alpha = 1
        self.q_lr = 3e-4
        self.policy_lr = 1e-4
        self.policy_eval_start = 0

        # --- Entropy temperature ----------------------------------------
        self.alpha_lr = 3e-4

        # --- Network / optimizer ----------------------------------------
        self.hidden_sizes = [256, 256]
        self.initial_learning_rate = 3e-4
        # None means "let the agent fill this in".
        self.target_entropy = None
        self.init_alpha_rs = 1
        self.init_alpha_d = 1

        # --- Sample-weight warm-up --------------------------------------
        # Uniform weights are used while step < warm_up_steps; afterwards
        # the blend ramps linearly over ramp_up_steps further steps.
        self.warm_up_steps = 3000
        self.ramp_up_steps = 50000

        # --- CQL penalty -------------------------------------------------
        self.temp = 1.0
        self.min_q_weight = 0.005
        self.num_random = 10
        self.with_lagrange = False
        self.lagrange_thresh = 10.0

        # --- Reshuffling temperature schedule ---------------------------
        self.initial_tau = 5.0
        self.min_tau = 0.05
        self.tau_decay = 0.999

        self.max_tau = 10.0
        self.kl_div_threshold = 3.0

        # --- Conservative penalty weight --------------------------------
        self.lambda_cql = 10

        # --- Target network updates -------------------------------------
        self.target_update_freq = 1
        self.soft_update_tau = 0.005

        # --- Run length / logging cadence -------------------------------
        self.max_gradient_steps = 35000
        self.eval_interval = 10000
        self.save_interval = 20000

        # --- Behavior-density GMM ---------------------------------------
        self.gmm_num_components = 20
        self.gmm_max_iter = 500

    def get_tau(self, step) -> float:
        """Return the reshuffling temperature for ``step``.

        The temperature is ``initial_tau * tau_decay ** (step // 1000)``,
        floored at ``min_tau``. The value is produced by TensorFlow ops, so
        the returned object is whatever ``tf.maximum`` yields.
        """
        thousands = tf.cast(step // 1000, tf.float32)
        decayed = self.initial_tau * (self.tau_decay ** thousands)
        return tf.maximum(self.min_tau, decayed)

def set_seeds(seed_value=42):
    """Seed Python's ``random``, NumPy and TensorFlow with the same value."""
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed_value)

# ----------------------------------------------------
# 2. Network Architectures (unchanged)
# ----------------------------------------------------
def create_double_q_networks(state_dim, action_dim, hidden_sizes):
    """
    Build two independent Q-networks, 'Q1' and 'Q2', over one shared pair of inputs.

    Each network feeds the concatenated (state, action) vector through one
    Dense + ReLU stage per entry of ``hidden_sizes`` and ends in a single
    linear output unit. All Dense layers use Glorot-uniform kernels and
    zero biases; no normalization layers are applied.

    Returns:
        (q1_model, q2_model): keras Models taking ``[state, action]``.
    """
    state_input = keras.Input(shape=(state_dim,), dtype=tf.float32)
    action_input = keras.Input(shape=(action_dim,), dtype=tf.float32)
    joint = layers.Concatenate(axis=-1)([state_input, action_input])

    def build_critic(head_name, model_name):
        # One full tower: hidden stack, scalar head, wrapped in its own Model.
        hidden = joint
        for width in hidden_sizes:
            dense = layers.Dense(
                width,
                kernel_initializer='glorot_uniform',
                bias_initializer='zeros',
            )
            hidden = layers.ReLU()(dense(hidden))
        head = layers.Dense(
            1,
            kernel_initializer='glorot_uniform',
            bias_initializer='zeros',
            name=head_name,
        )
        return keras.Model([state_input, action_input], head(hidden), name=model_name)

    # Towers are built strictly one after the other (Q1 completely, then Q2).
    q1_model = build_critic('q1_out', 'Q1')
    q2_model = build_critic('q2_out', 'Q2')
    return q1_model, q2_model

def create_policy_network(state_dim, action_dim, hidden_sizes, max_action):
    """
    Build a Gaussian policy trunk with two linear heads.

    The state passes through one ReLU Dense layer (He-uniform kernels) per
    entry of ``hidden_sizes``. Two Dense heads of width ``action_dim`` then
    produce the mean and the log standard deviation; both heads use kernels
    drawn uniformly from [-3e-3, 3e-3].

    ``max_action`` is accepted but not used inside this function.

    Returns:
        keras Model mapping a state to ``[mean, log_std]``.
    """
    obs_input = keras.Input(shape=(state_dim,), dtype=tf.float32)

    # Shared hidden trunk
    features = obs_input
    for width in hidden_sizes:
        trunk_layer = layers.Dense(
            width,
            activation='relu',
            kernel_initializer=keras.initializers.HeUniform(),
        )
        features = trunk_layer(features)

    # Mean head: tiny uniform kernel, zero bias
    mean_head = layers.Dense(
        action_dim,
        kernel_initializer=keras.initializers.RandomUniform(-3e-3, 3e-3),
        bias_initializer='zeros',
    )
    mean = mean_head(features)

    # Log-std head: same tiny uniform kernel, bias initializer left at the layer default
    log_std_head = layers.Dense(
        action_dim,
        kernel_initializer=keras.initializers.RandomUniform(-3e-3, 3e-3),
    )
    log_std = log_std_head(features)

    return keras.Model(inputs=obs_input, outputs=[mean, log_std])


def create_diff_network(state_dim, action_dim, hidden_sizes, name):
    """
    Build a scalar network over (state, action, next_state).

    The three inputs are concatenated as-is (no normalization), passed
    through one ReLU Dense layer per entry of ``hidden_sizes`` (Glorot-uniform
    kernels, zero biases) and reduced to one linear output whose kernel uses
    an orthogonal initializer with gain 0.01.

    Returns:
        keras Model called ``name`` taking ``[state, action, next_state]``.
    """
    state_input = keras.Input(shape=(state_dim,), dtype=tf.float32)
    action_input = keras.Input(shape=(action_dim,), dtype=tf.float32)
    next_state_input = keras.Input(shape=(state_dim,), dtype=tf.float32)
    model_inputs = [state_input, action_input, next_state_input]

    hidden = layers.Concatenate(axis=-1)(model_inputs)
    for width in hidden_sizes:
        hidden = layers.Dense(
            width,
            activation='relu',
            kernel_initializer='glorot_uniform',
            bias_initializer='zeros',
        )(hidden)

    # Output layer starts with very small weights
    output_layer = layers.Dense(
        1,
        activation=None,
        kernel_initializer=keras.initializers.orthogonal(0.01),
        bias_initializer='zeros',
    )
    return keras.Model(model_inputs, output_layer(hidden), name=name)


# ----------------------------------------------------
# 3. The ReshuffleSACAgent (no architecture changes)
# ----------------------------------------------------
class ReshuffleSACAgent:
    def __init__(self, env, config: Config):
        """
        Build every network, temperature variable and optimizer of the agent.

        Args:
            env: environment whose ``observation_space.shape[0]`` and
                ``action_space.shape[0]`` give the state/action sizes and
                whose ``action_space.high[0]`` is used as the action scale
                for every action dimension.
            config: hyperparameters. When ``config.target_entropy`` is
                ``None`` it is set in place to ``-act_dim``.
        """
        self.env = env
        self.config = config
        if config.with_lagrange:
            # Log-space multipliers and their optimizers for the Lagrange
            # variant of the CQL penalty (one pair each for the RS and D names).
            self.log_alpha_prime_rs = tf.Variable(0.0)
            self.alpha_prime_optimizer_rs = tf.keras.optimizers.Adam(config.q_lr)
            self.log_alpha_prime_d = tf.Variable(0.0)
            self.alpha_prime_optimizer_d = tf.keras.optimizers.Adam(config.q_lr)
        obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.shape[0]
        self.max_action = float(env.action_space.high[0])
        self.state_dim = obs_dim
        self.action_dim = act_dim

        self.global_step = tf.Variable(0, dtype=tf.int64)

        # Critics: online, target, and "old" snapshots of both, for the RS and D families.
        # NOTE: construction order is kept fixed so seeded initializations stay reproducible.
        self.Q_RS1, self.Q_RS2 = create_double_q_networks(obs_dim, act_dim, config.hidden_sizes)
        self.Q_D1, self.Q_D2   = create_double_q_networks(obs_dim, act_dim, config.hidden_sizes)
        self.target_Q_RS1, self.target_Q_RS2 = create_double_q_networks(obs_dim, act_dim, config.hidden_sizes)
        self.target_Q_D1,  self.target_Q_D2  = create_double_q_networks(obs_dim, act_dim, config.hidden_sizes)
        self.Q_RS1_old, self.Q_RS2_old = create_double_q_networks(obs_dim, act_dim, config.hidden_sizes)
        self.Q_D1_old,  self.Q_D2_old  = create_double_q_networks(obs_dim, act_dim, config.hidden_sizes)
        self.target_Q_RS1_old, self.target_Q_RS2_old = create_double_q_networks(obs_dim, act_dim, config.hidden_sizes)
        self.target_Q_D1_old,  self.target_Q_D2_old  = create_double_q_networks(obs_dim, act_dim, config.hidden_sizes)

        # RS actor and the two difference networks
        self.actor = create_policy_network(obs_dim, act_dim, config.hidden_sizes, self.max_action)
        self.N1 = create_diff_network(obs_dim, act_dim, config.hidden_sizes, "N1")
        self.N2 = create_diff_network(obs_dim, act_dim, config.hidden_sizes, "N2")

        # Entropy temperatures (log space). The default target entropy follows
        # the action dimensionality instead of a value hard-coded for 6-D actions.
        if config.target_entropy is None:
            self.config.target_entropy = -act_dim
        self.log_alpha_rs = tf.Variable(np.log(config.init_alpha_rs), dtype=tf.float32)
        self.log_alpha_d = tf.Variable(np.log(config.init_alpha_d), dtype=tf.float32)

        # Copy online weights into the targets, then snapshot both into the "old" copies
        self.sync_networks_hard()
        self.sync_old_networks()

        # One shared staircase exponential-decay schedule object for the critic,
        # actor and difference-network optimizers below
        lr_schedule = keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=config.initial_learning_rate,
            decay_steps=20000,
            decay_rate=0.99,
            staircase=True
        )

        # D actor (same architecture as the RS actor)
        self.actor_d = create_policy_network(obs_dim, act_dim, config.hidden_sizes, self.max_action)

        self.q_rs1_opt = keras.optimizers.Adam(lr_schedule)
        self.q_rs2_opt = keras.optimizers.Adam(lr_schedule)
        self.q_d1_opt = keras.optimizers.Adam(lr_schedule)
        self.q_d2_opt = keras.optimizers.Adam(lr_schedule)

        self.actor_opt = keras.optimizers.Adam(lr_schedule)
        self.actor_d_opt = keras.optimizers.Adam(lr_schedule)

        # Temperature optimizers use a constant learning rate
        self.alpha_rs_opt = keras.optimizers.Adam(config.alpha_lr)
        self.alpha_d_opt = keras.optimizers.Adam(config.alpha_lr)
        self.n1_opt    = keras.optimizers.Adam(lr_schedule)
        self.n2_opt    = keras.optimizers.Adam(lr_schedule)

        # Manual temperature override; calculate_weights uses it instead of
        # the schedule whenever it is > 0
        self.tau_override = tf.Variable(0.0, dtype=tf.float32, trainable=False)
    @tf.function
    def compute_cql_penalty(
        self,
        q_network,
        states,
        next_states,
        actions,
        q_pred,            # Q(s, a_data), as already computed by the caller
        log_alpha_prime,
        sample_fn
    ):
        """
        Compute a CQL-style conservative penalty for one Q-network.

        The penalty is ``temp * logsumexp(candidates / temp) - Q(s, a_data)``
        averaged over the batch and multiplied by ``config.min_q_weight``.
        The candidates per sample are the data Q-value plus ``num_random``
        values each from uniformly sampled actions, policy actions at the
        current states and policy actions at the next states.

        :param q_network: callable taking ``[states, actions]``.
        :param states: batch of states; its first dimension is the batch size.
        :param next_states: batch of next states (same batch size).
        :param actions: data actions; only its second dimension (action size) is read.
        :param q_pred: Q-values of the data actions; reshaped to [batch, 1] here.
        :param log_alpha_prime: log of the Lagrange multiplier; only read when
                                ``config.with_lagrange`` is true.
        :param sample_fn: callable mapping states to ``(actions, log_probs)``.

        :return: the scalar penalty, or the tuple
                 ``(penalty, alpha_prime_loss)`` when ``config.with_lagrange``
                 is true. NOTE(review): the return type differs between the
                 two modes, so callers must unpack accordingly.
        """
        batch_size = tf.shape(states)[0]                # dynamic batch size
        act_dim    = tf.shape(actions)[1]
        num_rand   = self.config.num_random            # samples per state and per source

        # ------------------------------------------------------------
        # 1) Random actions (sampled from Uniform[-1,1]) for s
        # ------------------------------------------------------------
        # Each state row is repeated num_rand times consecutively, giving
        # shape [batch_size*num_rand, obs_dim].
        repeated_states = tf.repeat(states, repeats=num_rand, axis=0)
        # random actions => shape [batch_size*num_rand, act_dim]
        # NOTE(review): the range is fixed to [-1, 1] and does not use
        # self.max_action — confirm this matches the environment's action range.
        random_actions = tf.random.uniform(
            shape=(batch_size * num_rand, act_dim),
            minval=-1.0,
            maxval=1.0
        )
        # Evaluate Q(s, a_rand)
        q_rand = q_network([repeated_states, random_actions])  # => [batch_size*num_rand, 1]
        q_rand = tf.reshape(q_rand, [batch_size, num_rand])    # => [batch_size, num_rand]

        # Importance correction: subtract the log-density of the uniform
        # proposal, act_dim * log(0.5) (density 0.5 per dimension on [-1, 1]).
        random_density = tf.cast(act_dim, tf.float32) * tf.math.log(tf.constant(0.5, dtype=tf.float32))
        q_rand_adjusted = q_rand - random_density
        # q_rand_adjusted (not q_rand) is what enters the log-sum-exp below.

        # ------------------------------------------------------------
        # 2) Data action Q(s, a_data) is given by q_pred
        #    It is kept shaped [batch_size, 1].
        # ------------------------------------------------------------
        q_data = tf.reshape(q_pred, [batch_size, 1])

        # ------------------------------------------------------------
        # 3) Current policy actions from 'states'
        # ------------------------------------------------------------
        # num_rand policy samples per state => repeat the states again
        repeated_states2 = tf.repeat(states, repeats=num_rand, axis=0)
        curr_actions, curr_log_pis = sample_fn(repeated_states2)
        # Evaluate Q
        q_curr = q_network([repeated_states2, curr_actions])    # => [batch_size*num_rand, 1]
        # Reshape to [batch_size, num_rand]
        q_curr = tf.reshape(q_curr, [batch_size, num_rand])
        curr_log_pis = tf.reshape(curr_log_pis, [batch_size, num_rand])
        # Importance correction with the policy's own log-probability: Q - log_pi
        q_curr_adjusted = q_curr - curr_log_pis

        # ------------------------------------------------------------
        # 4) Next policy actions sampled at next_states, evaluated at
        #    next_states as well, i.e. Q(s', a') with a' ~ pi(.|s').
        #    NOTE(review): earlier comments in this spot described feeding
        #    the *current* states, Q(s, a'), but the code below passes
        #    repeated_next_states — confirm which variant is intended.
        # ------------------------------------------------------------
        repeated_next_states = tf.repeat(next_states, repeats=num_rand, axis=0)
        next_actions_samps, next_log_pis = sample_fn(repeated_next_states)
        q_next = q_network([repeated_next_states, next_actions_samps])
        q_next = tf.reshape(q_next, [batch_size, num_rand])         # => [batch_size, num_rand]
        next_log_pis = tf.reshape(next_log_pis, [batch_size, num_rand])
        q_next_adjusted = q_next - next_log_pis

        # ------------------------------------------------------------
        # 5) Concatenate all candidate Q-values: data, random, curr, next
        # ------------------------------------------------------------
        cat_q = tf.concat(
            [
                q_data,           # shape [batch_size, 1]
                q_rand_adjusted,  # shape [batch_size, num_rand]
                q_curr_adjusted,
                q_next_adjusted
            ],
            axis=1
        ) # => shape [batch_size, 1 + 3*num_rand]

        # ------------------------------------------------------------
        # 6) Temperature-scaled log-sum-exp across candidates => [batch_size]
        # ------------------------------------------------------------
        cql_logsumexp = tf.reduce_logsumexp(cat_q / self.config.temp, axis=1)
        cql_logsumexp = cql_logsumexp * self.config.temp

        # 7) Subtract data Q
        cql_loss = cql_logsumexp - tf.squeeze(q_data, axis=1)  # both shaped [batch_size]

        # 8) Mean across batch
        cql_loss = tf.reduce_mean(cql_loss)

        # 9) Multiply by min_q_weight
        cql_loss *= self.config.min_q_weight

        # 10) Lagrange version: scale the gap to lagrange_thresh by a learned
        #     multiplier and also return the multiplier's own loss
        if self.config.with_lagrange:
            alpha_prime = tf.clip_by_value(tf.exp(log_alpha_prime), 0.0, 1.0e6)
            lagrange_obj = cql_loss - self.config.lagrange_thresh
            cql_loss = alpha_prime * lagrange_obj
            alpha_prime_loss = -cql_loss  # the multiplier is trained on the negated penalty
            return cql_loss, alpha_prime_loss

        return cql_loss


    def act_d(self, state, deterministic=False):
        """
        Pick an action for a single state with the D actor.

        With ``deterministic=True`` the squashed mean is returned; otherwise
        Gaussian noise scaled by the (clipped) standard deviation is added
        before squashing. The result is a NumPy array scaled by ``max_action``.
        """
        obs = tf.convert_to_tensor(state.reshape(1,-1), dtype=tf.float32)
        mu, raw_log_std = self.actor_d(obs)
        sigma = tf.exp(tf.clip_by_value(raw_log_std, -20, 2))

        if deterministic:
            latent = mu
        else:
            noise = tf.random.normal(shape=mu.shape)
            latent = mu + sigma * noise

        squashed = tf.tanh(latent) * self.max_action
        return squashed[0].numpy()
    def sync_networks_hard(self):
        """Overwrite each target critic's weights with those of its online critic."""
        online_to_target = (
            (self.Q_RS1, self.target_Q_RS1),
            (self.Q_RS2, self.target_Q_RS2),
            (self.Q_D1, self.target_Q_D1),
            (self.Q_D2, self.target_Q_D2),
        )
        for online, target in online_to_target:
            target.set_weights(online.get_weights())
    @tf.function
    def policy_diagnosis(self, states):
        """
        Print min/max of the RS actor's mean and std, of freshly sampled
        actions, and of Q_RS1 on those actions; return the sampled actions.
        """
        def span(tensor):
            # (min, max) pair of a tensor
            return tf.reduce_min(tensor), tf.reduce_max(tensor)

        mu, raw_log_std = self.actor(states)
        sigma = tf.exp(tf.clip_by_value(raw_log_std, -20, 2))
        tf.print("\nPolicy Diagnostics:")
        tf.print("Mean range:", *span(mu))
        tf.print("Std range:", *span(sigma))

        sampled, _ = self.sample_actions_logp(states)
        tf.print("Action range:", *span(sampled))

        q_first = self.Q_RS1([states, sampled])
        # Q_RS2 is evaluated as well, but only Q_RS1 is reported.
        self.Q_RS2([states, sampled])
        tf.print("Q-values range:", *span(q_first))

        return sampled

    @tf.function
    def value_diagnosis(self, states, actions):
        """
        Print min/mean/max of the four online critics on (states, actions),
        then of target_Q_RS1 and the log-probabilities for actions freshly
        sampled from the RS actor at the same states.
        """
        def summary(tensor):
            # (min, mean, max) triple of a tensor
            return tf.reduce_min(tensor), tf.reduce_mean(tensor), tf.reduce_max(tensor)

        tf.print("\nValue Diagnostics:")

        # Online critics, evaluated first and reported afterwards
        online_stats = (
            ("Q_RS1 stats:", self.Q_RS1([states, actions])),
            ("Q_RS2 stats:", self.Q_RS2([states, actions])),
            ("Q_D1 stats:", self.Q_D1([states, actions])),
            ("Q_D2 stats:", self.Q_D2([states, actions])),
        )
        for label, values in online_stats:
            tf.print(label, *summary(values))

        # Target critics on policy samples drawn at `states`
        sampled_actions, sampled_logp = self.sample_actions_logp(states)
        target_first = self.target_Q_RS1([states, sampled_actions])
        # target_Q_RS2 is evaluated as well, but only target_Q_RS1 is reported.
        self.target_Q_RS2([states, sampled_actions])

        tf.print("Target Q stats:", *summary(target_first))
        tf.print("Next logp stats:", *summary(sampled_logp))

    def sync_old_networks(self):
        """Snapshot the current online and target critics into their ``*_old`` copies."""
        current_to_old = (
            (self.Q_RS1, self.Q_RS1_old),
            (self.Q_RS2, self.Q_RS2_old),
            (self.Q_D1, self.Q_D1_old),
            (self.Q_D2, self.Q_D2_old),
            (self.target_Q_RS1, self.target_Q_RS1_old),
            (self.target_Q_RS2, self.target_Q_RS2_old),
            (self.target_Q_D1, self.target_Q_D1_old),
            (self.target_Q_D2, self.target_Q_D2_old),
        )
        for current, snapshot in current_to_old:
            snapshot.set_weights(current.get_weights())

    def soft_update(self, source_vars, target_vars):
        """
        Move each target variable a fraction ``config.soft_update_tau`` of the
        way toward its paired source variable, in place. Variables are paired
        positionally; extra entries in the longer sequence are ignored.
        """
        mix = self.config.soft_update_tau
        for src, dst in zip(source_vars, target_vars):
            blended = mix * src + (1 - mix) * dst
            dst.assign(blended)

    def soft_update_all(self):
        """Apply ``soft_update`` from every online critic to its target critic."""
        online_to_target = (
            (self.Q_RS1, self.target_Q_RS1),
            (self.Q_RS2, self.target_Q_RS2),
            (self.Q_D1, self.target_Q_D1),
            (self.Q_D2, self.target_Q_D2),
        )
        for online, target in online_to_target:
            self.soft_update(online.trainable_variables, target.trainable_variables)

    def act(self, state, deterministic=False):
        """
        Pick an action for a single state with the RS actor.

        With ``deterministic=True`` the squashed mean is returned; otherwise
        Gaussian noise scaled by the (clipped) standard deviation is added
        before squashing. The result is a NumPy array scaled by ``max_action``.
        """
        obs = tf.convert_to_tensor(state.reshape(1,-1), dtype=tf.float32)
        mu, raw_log_std = self.actor(obs)
        sigma = tf.exp(tf.clip_by_value(raw_log_std, -20, 2))

        if deterministic:
            latent = mu
        else:
            noise = tf.random.normal(shape=mu.shape)
            latent = mu + sigma * noise

        squashed = tf.tanh(latent) * self.max_action
        return squashed[0].numpy()

    @tf.function
    def sample_actions_logp(self, states):
        """
        Sample tanh-squashed actions from the RS actor with their log-probabilities.

        :param states: batch of states fed to ``self.actor``.

        :return: ``(actions, log_prob)``. ``actions`` are the squashed samples
                 scaled by ``self.max_action``. ``log_prob`` has shape
                 [batch, 1]: the Gaussian log-density of the pre-tanh sample
                 minus the tanh change-of-variables term, summed over action
                 dimensions. The constant ``max_action`` scaling is not
                 included in the log-probability.
        """
        mean, log_std = self.actor(states)
        std = tf.exp(tf.clip_by_value(log_std, -20, 2))

        # Diagonal Gaussian over pre-tanh actions
        distribution = tfp.distributions.Normal(mean, std)
        raw_actions = distribution.sample()

        # Squash into (-1, 1)
        actions = tf.tanh(raw_actions)

        # Apply the tanh correction per dimension and sum once afterwards.
        # Subtracting the already-summed correction from every dimension
        # (and summing again) would count it act_dim times.
        per_dim_logp = distribution.log_prob(raw_actions)
        per_dim_logp -= tf.math.log(1.0 - tf.square(actions) + 1e-6)
        log_prob = tf.reduce_sum(per_dim_logp, axis=1, keepdims=True)

        return actions * self.max_action, log_prob

    # Same for the D policy
    @tf.function
    def sample_actions_logp_d(self, states):
        """
        Sample tanh-squashed actions from the D actor with their log-probabilities.

        :param states: batch of states fed to ``self.actor_d``.

        :return: ``(actions, log_prob)``. ``actions`` are the squashed samples
                 scaled by ``self.max_action``. ``log_prob`` has shape
                 [batch, 1]: the Gaussian log-density of the pre-tanh sample
                 minus the tanh change-of-variables term, summed over action
                 dimensions. The constant ``max_action`` scaling is not
                 included in the log-probability.
        """
        mean, log_std = self.actor_d(states)
        std = tf.exp(tf.clip_by_value(log_std, -20, 2))

        distribution = tfp.distributions.Normal(mean, std)
        raw_actions = distribution.sample()
        actions = tf.tanh(raw_actions)

        # Apply the tanh correction per dimension and sum once afterwards.
        # Subtracting the already-summed correction from every dimension
        # (and summing again) would count it act_dim times.
        per_dim_logp = distribution.log_prob(raw_actions)
        per_dim_logp -= tf.math.log(1.0 - tf.square(actions) + 1e-6)
        log_prob = tf.reduce_sum(per_dim_logp, axis=1, keepdims=True)

        return actions * self.max_action, log_prob
    # GMM + PCA for p^D
    def fit_behavior_gmm(self, all_states, all_actions, pca_dim=15):
        """
        Fit the behavior-density model: PCA on (state, action) rows followed by a GMM.

        At most 100000 rows are used, drawn without replacement. The fitted
        objects are stored as ``self.pca`` and ``self.gmm``.

        :param all_states: array of states, one row per sample.
        :param all_actions: array of actions with the same number of rows.
        :param pca_dim: requested number of PCA components. It is reduced to
                        ``min(n_samples, n_features)`` of the subset when
                        larger, since PCA cannot keep more components than that.
        """
        total_samples = all_states.shape[0]
        subset_size = min(100000, total_samples)
        idxes = np.random.choice(total_samples, size=subset_size, replace=False)
        states_subset = all_states[idxes]
        actions_subset = all_actions[idxes]

        X = np.concatenate([states_subset, actions_subset], axis=1)

        # Clamp the component count so low-dimensional environments (or tiny
        # datasets) do not make PCA raise.
        n_components = min(pca_dim, X.shape[0], X.shape[1])
        if n_components < pca_dim:
            print(f"PCA: requested {pca_dim} components but data allows at most {n_components}; using {n_components}.")

        self.pca = PCA(n_components=n_components)
        X_reduced = self.pca.fit_transform(X)
        self.gmm = GaussianMixture(
            n_components=self.config.gmm_num_components,
            max_iter=self.config.gmm_max_iter,
            covariance_type='full',
            verbose=2
        )
        start = time.time()
        self.gmm.fit(X_reduced)
        print(f"GMM fit done in {time.time()-start:.1f} sec. Using {subset_size} samples.")

    def eval_p_d(self, states, actions):
        """
        Evaluate the fitted behavior density at (state, action) pairs.

        The pairs are concatenated along the last axis, projected with the
        stored PCA, scored by the stored GMM, and the exponentiated
        log-densities are returned as float32. Requires ``fit_behavior_gmm``
        to have set ``self.pca`` and ``self.gmm``.
        """
        joint = np.concatenate((states, actions), axis=-1)
        log_density = self.gmm.score_samples(self.pca.transform(joint))
        density = np.exp(log_density)
        return np.float32(density)

    @tf.function
    def sac_backup(self, next_states, rewards, dones, target_q1, target_q2, alpha_val):
        """
        Soft Bellman target using the RS actor.

        Actions are sampled at ``next_states``; the smaller of the two target
        critics minus ``alpha_val`` times the sample's log-probability is
        discounted by ``config.gamma``, masked by ``1 - dones`` and added to
        ``rewards``.
        """
        sampled_actions, sampled_logp = self.sample_actions_logp(next_states)
        critic_inputs = [next_states, sampled_actions]
        pessimistic_q = tf.minimum(target_q1(critic_inputs), target_q2(critic_inputs))

        soft_value = pessimistic_q - alpha_val  * sampled_logp
        discount = self.config.gamma * (1.0 - dones)
        return rewards + discount * soft_value
    @tf.function
    def sac_backup_d(self, next_states, rewards, dones, target_q1, target_q2, alpha_val):
        """
        Soft Bellman target using the D actor.

        Actions are sampled at ``next_states`` from ``actor_d``; the smaller
        of the two target critics minus ``alpha_val`` times the sample's
        log-probability is discounted by ``config.gamma``, masked by
        ``1 - dones`` and added to ``rewards``.
        """
        sampled_actions, sampled_logp = self.sample_actions_logp_d(next_states)
        critic_inputs = [next_states, sampled_actions]
        pessimistic_q = tf.minimum(target_q1(critic_inputs), target_q2(critic_inputs))

        soft_value = pessimistic_q - alpha_val  * sampled_logp
        discount = self.config.gamma * (1.0 - dones)
        return rewards + discount * soft_value

    def partial_blend_weights(self, uniform_w, discor_w, step_int):
        """
        Linearly blend uniform and DisCor weights according to the step.

        The blend factor is 0 up to ``config.warm_up_steps``, rises linearly
        over the following ``config.ramp_up_steps`` steps and is 1 after
        that. The result is ``discor_w * factor + uniform_w * (1 - factor)``.
        """
        ramp_begin = float(self.config.warm_up_steps)
        ramp_length = float(self.config.ramp_up_steps)
        ramp_finish = ramp_begin + ramp_length

        current_step = tf.cast(step_int, tf.float32)
        progress = (current_step - ramp_begin)/(ramp_finish - ramp_begin)
        factor = tf.clip_by_value(progress, 0.0, 1.0)

        return discor_w * factor + uniform_w * (1 - factor)

    @tf.function
    def calculate_weights(self, q_rs_vals, q_d_vals, p_d_batch, n1_pred, n2_pred):
        """
        Derive per-sample weights from the N1/N2 predictions.

        ``q_rs_vals`` and ``q_d_vals`` are accepted but not read here.
        ``p_d_batch`` holds the behavior densities of the minibatch samples
        and only enters the KL diagnostic.

        Steps: pick the temperature (``tau_override`` when > 0, otherwise the
        scheduled value for ``global_step``), compute
        ``| -(0.5 + 0.5*|n2|) / tau + log(|n1| + 1e-8) |``, divide by its
        batch median, and clip to the batch's 1st..99th percentile range.

        :return: ``(weights, kl_div, tau)`` where ``weights`` are the clipped
                 scores, ``kl_div`` is ``sum(p_rs * (log p_rs - log p_d))``
                 computed on the unclipped scores, and ``tau`` is the
                 temperature that was used.
        """
        def overridden_tau():
            return self.tau_override

        def scheduled_tau():
            return tf.cast(self.config.get_tau(self.global_step), tf.float32)

        tau = tf.cond(self.tau_override > 0.0, overridden_tau, scheduled_tau)

        # Raw score from the two difference networks
        n2_soft = 0.5 + 0.5 * tf.abs(n2_pred)
        raw_score = -n2_soft / tau + tf.math.log(tf.abs(n1_pred) + 1e-8)
        score = tf.abs(raw_score)

        # Normalize by the batch median
        median = tfp.stats.percentile(score, 50.0) + 1e-8
        p_rs = score / median

        # Clip to the [1st, 99th] percentile band (both taken on unclipped p_rs)
        upper = tfp.stats.percentile(p_rs, q=99)
        capped = tf.minimum(p_rs, upper)
        lower = tfp.stats.percentile(p_rs, q=1)
        weights = tf.maximum(capped, lower)

        # Diagnostic: divergence-style sum against the behavior densities
        log_ratio = tf.math.log(p_rs + 1e-8) - tf.math.log(p_d_batch + 1e-8)
        kl_div = tf.reduce_sum(p_rs * log_ratio)

        return weights, kl_div, tau


    # Add to training loop:


    @tf.function
    def get_log_prob_of_data(self, actor, states, data_actions, max_action):
        """
        Log-probability of dataset actions under a tanh-Gaussian policy.

        :param actor: policy network returning ``(mean, log_std)`` for states.
        :param states: batch of states passed to ``actor``.
        :param data_actions: batch of actions; they are divided by
                             ``max_action`` and clipped to +-0.999999 before
                             the inverse tanh.
        :param max_action: scale used to bring ``data_actions`` into [-1, 1].

        :return: log-probabilities of shape [batch, 1] (Gaussian log-density
                 of the pre-tanh action summed over dimensions, minus the
                 summed tanh Jacobian term).
        """
        # Undo the squashing: atanh of the rescaled, slightly clipped actions
        squashed = tf.clip_by_value(data_actions / max_action, -0.999999, 0.999999)
        pre_tanh = 0.5 * tf.math.log((1.0 + squashed)/(1.0 - squashed))

        # Gaussian parameters predicted for these states
        mu, raw_log_std = actor(states)
        sigma = tf.exp(tf.clip_by_value(raw_log_std, -20, 2))
        gaussian = tfp.distributions.Normal(loc=mu, scale=sigma)

        # Gaussian part, summed over action dimensions => [batch, 1]
        gaussian_logp = tf.reduce_sum(gaussian.log_prob(pre_tanh), axis=1, keepdims=True)

        # Change-of-variables term for tanh, summed over action dimensions
        jacobian_term = tf.reduce_sum(
            tf.math.log(1.0 - tf.square(tf.tanh(pre_tanh)) + 1e-6),
            axis=1, keepdims=True
        )

        return gaussian_logp - jacobian_term

    @tf.function
    def train_step(self, states, actions, rewards, next_states, dones, p_d_vals):
        """Run one training step on a batch of offline transitions.

        In order, this updates: the twin Q_RS critics (weighted Bellman error
        plus the penalty from ``compute_cql_penalty``), the twin Q_D critics
        (unweighted Bellman error plus the same kind of penalty), the
        difference networks N1 and N2, both actors (behaviour cloning while
        ``global_step < config.policy_eval_start``, SAC-style afterwards,
        together with both temperatures), and finally the target networks
        every ``config.target_update_freq`` steps. ``self.global_step`` is
        incremented by one at the end.

        Args:
            states, actions, rewards, next_states, dones: one batch of
                transitions; they are passed straight to the networks and to
                ``sac_backup`` / ``sac_backup_d``.
                # assumes a leading batch axis on every tensor - TODO confirm
            p_d_vals: per-sample values forwarded to ``calculate_weights``;
                only column 0 is read, so it must be at least 2-D.

        Returns:
            dict with the keys ``loss_q_rs``, ``loss_q_d``, ``loss_n1``,
            ``loss_n2``, ``kl_div``, ``actor_loss``, ``alpha_loss``,
            ``alpha_val``, ``bellman_err``, ``bellman_err2``, ``value_err``,
            ``weight_mean`` and ``tau_used``. ``actor_loss``, ``alpha_loss``
            and ``alpha_val`` describe the RS policy only; the D-policy
            counterparts are computed but not returned. During the
            behaviour-cloning phase ``actor_loss`` and ``alpha_loss`` keep
            their 0.0 placeholders (the BC losses are not reported).
        """
        # Temperatures are read once, before any update in this step, so the
        # "alpha_val" that is returned is the pre-update value.
        alpha_rs_val = tf.exp(self.log_alpha_rs)
        alpha_d_val = tf.exp(self.log_alpha_d)

        # Persistent tape: four separate gradient calls (Q_RS, Q_D, N1, N2)
        # are taken from it below before it is deleted.
        with tf.GradientTape(persistent=True) as tape:
            # Evaluate the twin critics on the dataset actions; column 0 of
            # each output is used, and the element-wise minimum of each pair
            # is the pessimistic value.
            q_rs1_vals = self.Q_RS1([states, actions])[:, 0]
            q_rs2_vals = self.Q_RS2([states, actions])[:, 0]
            q_d1_vals  = self.Q_D1([states, actions])[:, 0]
            q_d2_vals  = self.Q_D2([states, actions])[:, 0]
            q_rs_vals  = tf.minimum(q_rs1_vals, q_rs2_vals)
            q_d_vals   = tf.minimum(q_d1_vals,  q_d2_vals)

            # Difference-network predictions for (s, a, s').
            n1_pred = self.N1([states, actions, next_states])[:, 0]
            n2_pred = self.N2([states, actions, next_states])[:, 0]

            # Uniform vs DisCor-style weights. calculate_weights is always
            # evaluated; whether its output is used is decided below.
            # NOTE(review): this runs inside the tape - if calculate_weights
            # does not stop gradients, the Q_RS gradient also flows through
            # the weights. Confirm against its implementation.
            batch_size = tf.shape(q_rs_vals)[0]
            uniform_w = tf.ones((batch_size,), dtype=tf.float32)
            discor_w, kl_div, tau_used = self.calculate_weights(q_rs_vals, q_d_vals, p_d_vals[:,0], n1_pred, n2_pred)

            step_int = self.global_step
            def combined_weights():
                # After warm-up: blend uniform and computed weights as a
                # function of the current step.
                return self.partial_blend_weights(uniform_w, discor_w, step_int), kl_div, tau_used

            # During warm-up the weights are uniform and kl/tau are reported
            # as 0.0.
            w_cond1 = (step_int < self.config.warm_up_steps)
            weights, kl_final, tau_final = tf.cond(
                w_cond1,
                lambda: (uniform_w, tf.constant(0.0, dtype=tf.float32), tf.constant(0.0, dtype=tf.float32)),
                lambda: combined_weights()
            )

            # NOTE: an earlier comment here announced a "raise tau when KL is
            # large" step, but no such logic exists in this method.


            # SAC backups. Both critics of a pair regress onto the same
            # target, hence the plain aliases.
            bq_rs1 = self.sac_backup(next_states, rewards, dones, self.target_Q_RS1, self.target_Q_RS2, alpha_rs_val)
            bq_rs2 = bq_rs1
            bq_d1  = self.sac_backup_d(next_states, rewards, dones, self.target_Q_D1,  self.target_Q_D2,  alpha_d_val)
            bq_d2  = bq_d1

            # NOTE(review): old_q_rs1 / old_q_rs2 / old_q_rs are never read
            # again in this method; only bq_rs_old feeds the N1 target.
            old_q_rs1 = self.Q_RS1_old([states, actions])[:, 0]
            old_q_rs2 = self.Q_RS2_old([states, actions])[:, 0]
            old_q_rs  = tf.minimum(old_q_rs1, old_q_rs2)
            bq_rs_old = self.sac_backup(next_states, rewards, dones, self.target_Q_RS1_old, self.target_Q_RS2_old, alpha_rs_val)

            # Weighted squared Bellman error for each Q_RS critic.
            bellman_loss_rs1 = tf.nn.compute_average_loss(
                weights * tf.square(bq_rs1[:,0] - q_rs1_vals)
            )
            bellman_loss_rs2 = tf.nn.compute_average_loss(
                weights * tf.square(bq_rs2[:,0] - q_rs2_vals)
            )

            # Conservative penalties: one per critic. RS critics sample from
            # the RS policy, D critics from the D policy. The sixth argument
            # is passed as None in every call.
            cpen_rs1 = self.compute_cql_penalty(
                self.Q_RS1, states, next_states, actions, q_rs1_vals,
                None,  # not used
                self.sample_actions_logp
            )
            cpen_rs2 = self.compute_cql_penalty(
                self.Q_RS2, states, next_states, actions, q_rs2_vals,
                None,
                self.sample_actions_logp
            )
            # No alpha_prime loss is computed anywhere in this method.
            cpen_d1 = self.compute_cql_penalty(
                self.Q_D1, states, next_states, actions, q_d1_vals,
                None,
                self.sample_actions_logp_d
            )
            cpen_d2 = self.compute_cql_penalty(
                self.Q_D2, states, next_states, actions, q_d2_vals,
                None,
                self.sample_actions_logp_d
            )

            loss_q_rs1 = bellman_loss_rs1 + cpen_rs1
            loss_q_rs2 = bellman_loss_rs2 + cpen_rs2
            total_q_rs_loss = loss_q_rs1 + loss_q_rs2

            # Q_D squared Bellman error; unlike Q_RS it is NOT re-weighted.
            bellman_loss_d1 = tf.nn.compute_average_loss(
                tf.square(bq_d1[:,0] - q_d1_vals)
            )
            bellman_loss_d2 = tf.nn.compute_average_loss(
                tf.square(bq_d2[:,0] - q_d2_vals)
            )
            loss_q_d = bellman_loss_d1 + bellman_loss_d2+cpen_d2+cpen_d1

            # N1 regresses onto Q_RS - BQ_RS_old. The target is not wrapped in
            # stop_gradient, but the gradient below is taken w.r.t. N1's
            # variables only, so the critics are not trained through it.
            n1_target = q_rs_vals - bq_rs_old[:,0]
            loss_n1 = tf.reduce_mean(tf.square(n1_target - n1_pred))

            # N2 regresses onto Q_RS - Q_D (same remark as for N1).
            n2_target = q_rs_vals - q_d_vals
            loss_n2 = tf.reduce_mean(tf.square(n2_target - n2_pred))

        # ---- Gradient Clipping: Q and difference networks ----
        # Critic Q_RS: gradients for both critics are clipped jointly (one
        # global norm of 10.0), then split back and applied by each critic's
        # own optimizer.
        q_rs_vars1 = self.Q_RS1.trainable_variables
        q_rs_vars2 = self.Q_RS2.trainable_variables
        q_rs_grads = tape.gradient(total_q_rs_loss, q_rs_vars1 + q_rs_vars2)
        q_rs_grads, _ = tf.clip_by_global_norm(q_rs_grads, 10.0)  # clip norm=10
        self.q_rs1_opt.apply_gradients(zip(q_rs_grads[:len(q_rs_vars1)], q_rs_vars1))
        self.q_rs2_opt.apply_gradients(zip(q_rs_grads[len(q_rs_vars1):], q_rs_vars2))

        # Critic Q_D: same joint-clip-then-split scheme.
        q_d_vars1 = self.Q_D1.trainable_variables
        q_d_vars2 = self.Q_D2.trainable_variables
        q_d_grads = tape.gradient(loss_q_d, q_d_vars1 + q_d_vars2)
        q_d_grads, _ = tf.clip_by_global_norm(q_d_grads, 10.0)
        self.q_d1_opt.apply_gradients(zip(q_d_grads[:len(q_d_vars1)], q_d_vars1))
        self.q_d2_opt.apply_gradients(zip(q_d_grads[len(q_d_vars1):], q_d_vars2))

        # Difference networks: each one is clipped and applied on its own.
        n1_vars = self.N1.trainable_variables
        n2_vars = self.N2.trainable_variables
        n1_grads = tape.gradient(loss_n1, n1_vars)
        n1_grads, _ = tf.clip_by_global_norm(n1_grads, 10.0)
        self.n1_opt.apply_gradients(zip(n1_grads, n1_vars))

        n2_grads = tape.gradient(loss_n2, n2_vars)
        n2_grads, _ = tf.clip_by_global_norm(n2_grads, 10.0)
        self.n2_opt.apply_gradients(zip(n2_grads, n2_vars))

        # Release the persistent tape now that all four gradients are taken.
        del tape
        # Re-read the step counter (same object as inside the tape above).
        step_int = self.global_step

        # Placeholders so the logged values exist whichever branch runs.
        actor_loss_rs = 0.0
        actor_loss_d  = 0.0
        alpha_rs_loss = 0.0
        alpha_d_loss  = 0.0
        # NOTE(review): global_step supports assign_add (see the end of this
        # method), so it is presumably a tf.Variable and this `if` is traced
        # as a graph conditional under @tf.function - TODO confirm.
        if step_int >= self.config.policy_eval_start:
            # ============= Normal CQL/SAC-style actor update =============
            # The critics are evaluated on freshly sampled policy actions.
            # Gradients are taken w.r.t. the actor variables only, so the
            # critics are not modified by this branch.
            with tf.GradientTape(persistent=True) as actor_tape:
                # RS policy
                new_actions, logp = self.sample_actions_logp(states)   # shape [B, act_dim], [B,1]
                q_rs1_pi = self.Q_RS1([states, new_actions])           # shape [B,1]
                q_rs2_pi = self.Q_RS2([states, new_actions])
                min_q_rs_pi = tf.minimum(q_rs1_pi, q_rs2_pi)

                actor_loss_rs = tf.reduce_mean(alpha_rs_val * logp - min_q_rs_pi)

                # D policy
                new_actions_d, logp_d = self.sample_actions_logp_d(states)
                q_d1_pi = self.Q_D1([states, new_actions_d])
                q_d2_pi = self.Q_D2([states, new_actions_d])
                min_q_d_pi = tf.minimum(q_d1_pi, q_d2_pi)

                actor_loss_d = tf.reduce_mean(alpha_d_val * logp_d - min_q_d_pi)

            actor_rs_grads = actor_tape.gradient(actor_loss_rs, self.actor.trainable_variables)
            actor_d_grads  = actor_tape.gradient(actor_loss_d,  self.actor_d.trainable_variables)

            actor_rs_grads, _ = tf.clip_by_global_norm(actor_rs_grads, 10.0)
            actor_d_grads, _  = tf.clip_by_global_norm(actor_d_grads, 10.0)

            self.actor_opt.apply_gradients(zip(actor_rs_grads, self.actor.trainable_variables))
            self.actor_d_opt.apply_gradients(zip(actor_d_grads, self.actor_d.trainable_variables))
            del actor_tape

            # Temperature updates. They reuse logp / logp_d sampled above
            # (before the actor step); stop_gradient keeps the policy out of
            # this loss, and the alpha gradients are applied unclipped.
            with tf.GradientTape() as alpha_rs_tape:
                # alpha_rs_loss = - E[ log_alpha_rs * (logp + target_entropy) ]
                alpha_rs_losses = -1.0 * (self.log_alpha_rs * tf.stop_gradient(logp + self.config.target_entropy))
                alpha_rs_loss   = tf.nn.compute_average_loss(alpha_rs_losses)

            alpha_rs_grads = alpha_rs_tape.gradient(alpha_rs_loss, [self.log_alpha_rs])
            self.alpha_rs_opt.apply_gradients(zip(alpha_rs_grads, [self.log_alpha_rs]))

            # Same for alpha_d; both temperatures share config.target_entropy.
            with tf.GradientTape() as alpha_d_tape:
                alpha_d_losses = -1.0 * (self.log_alpha_d * tf.stop_gradient(logp_d + self.config.target_entropy))
                alpha_d_loss   = tf.nn.compute_average_loss(alpha_d_losses)

            alpha_d_grads = alpha_d_tape.gradient(alpha_d_loss, [self.log_alpha_d])
            self.alpha_d_opt.apply_gradients(zip(alpha_d_grads, [self.log_alpha_d]))

        else:
            # ============= Early phase: do BC objective =============
            # Both policies are fitted to the dataset's (states, actions) by
            # minimising the negative log-likelihood of the logged actions.
            # The temperatures are not updated in this branch.
            with tf.GradientTape(persistent=True) as bc_tape:
                logprob_data_rs = self.get_log_prob_of_data(self.actor, states, actions, self.max_action)
                bc_loss_rs      = -tf.reduce_mean(logprob_data_rs)  # negative log-likelihood

                logprob_data_d  = self.get_log_prob_of_data(self.actor_d, states, actions, self.max_action)
                bc_loss_d       = -tf.reduce_mean(logprob_data_d)

            # An entropy term (alpha * logp) could be added to the BC loss;
            # plain negative log-likelihood is what is used here.

            bc_rs_grads = bc_tape.gradient(bc_loss_rs, self.actor.trainable_variables)
            bc_d_grads  = bc_tape.gradient(bc_loss_d,  self.actor_d.trainable_variables)

            bc_rs_grads, _ = tf.clip_by_global_norm(bc_rs_grads, 10.0)
            bc_d_grads, _  = tf.clip_by_global_norm(bc_d_grads, 10.0)

            self.actor_opt.apply_gradients(zip(bc_rs_grads, self.actor.trainable_variables))
            self.actor_d_opt.apply_gradients(zip(bc_d_grads, self.actor_d.trainable_variables))
            del bc_tape


        # Soft update of the target networks, gated on the step counter
        # before it is incremented (so it also fires at step 0).
        if (self.global_step % self.config.target_update_freq) == 0:
            self.soft_update_all()

        # Logging values. They reuse tensors computed before this step's
        # optimizer updates were applied.
        bellman_err = 0.5*(bellman_loss_rs1 + bellman_loss_rs2)
        bellman_err2 = 0.5*(bellman_loss_d1 + bellman_loss_d2)

        value_err = tf.reduce_mean(tf.square(q_rs_vals - q_d_vals))
        w_mean = tf.reduce_mean(weights)
        self.global_step.assign_add(1)
        # Only the RS temperature loss is reported.
        alpha_loss=alpha_rs_loss
        return {
            "loss_q_rs": total_q_rs_loss,
            "loss_q_d": loss_q_d,
            "loss_n1": loss_n1,
            "loss_n2": loss_n2,
            "kl_div": kl_final,
            "actor_loss": actor_loss_rs,
            "alpha_loss": alpha_loss,
            "alpha_val": alpha_rs_val,
            "bellman_err": bellman_err,
            "bellman_err2": bellman_err2,
            "value_err": value_err,
            "weight_mean": w_mean,
            "tau_used": tau_final
        }

# ----------------------------------------------------
# 4. Training Loop
# ----------------------------------------------------
def train_offline_sac_reshuffle(agent, env_id="HalfCheetah-v4",
                                seed=42,
                                max_steps=35000,
                                config=None,
                                log_wandb=False,
                                dataset_path='/content/drive/MyDrive/halfcheetah_medium_dataset.pkl'):
    """Train an agent on a pickled offline dataset and return it.

    The dataset is loaded from ``dataset_path``, turned into a shuffled,
    batched ``tf.data`` pipeline and iterated epoch after epoch until
    ``max_steps`` calls to ``agent.train_step`` have been made. The agent's
    "old" networks are synchronised at the start of every epoch, training
    metrics are written every 1000 steps, and both policies are evaluated
    every ``config.eval_interval`` steps.

    Args:
        agent: the agent to train. It is used as given; pass ``None`` to have
            a fresh ``ReshuffleSACAgent(env, config)`` built here, after
            ``set_seeds(seed)`` has run. (Previously this argument was
            silently replaced by a new agent.) An agent constructed by the
            caller was initialised before this function seeded anything.
        env_id: environment id passed to ``gym.make``; the environment is
            used for the periodic evaluations.
        seed: value passed to ``set_seeds``.
        max_steps: total number of gradient steps.
        config: hyper-parameter object; ``Config()`` when ``None``.
        log_wandb: when true, start a wandb run and log metrics every step.
        dataset_path: pickle file holding a mapping with the keys
            ``observations``, ``actions``, ``rewards``,
            ``next_observations`` and ``terminals``.
            # assumes the values are numpy arrays (``reshape`` is called on
            # rewards/terminals) - TODO confirm against the dataset writer

    Returns:
        The trained agent.

    Raises:
        ValueError: if the dataset contains no transitions (the epoch loop
            would otherwise never make progress).
    """
    if config is None:
        config = Config()
    set_seeds(seed)

    logger = CustomLogger()

    # One dict serves both the on-disk config record and the wandb run config.
    run_config = {
        "env_id": env_id,
        "seed": seed,
        "max_steps": max_steps,
        "gamma": config.gamma,
        "batch_size": config.batch_size,
        "buffer_capacity": config.buffer_capacity,
        "q_lr": config.q_lr,
        "policy_lr": config.policy_lr,
        "alpha_lr": config.alpha_lr,
        "hidden_sizes": config.hidden_sizes,
        "warm_up_steps": config.warm_up_steps,
        "ramp_up_steps": config.ramp_up_steps,
        "initial_tau": config.initial_tau,
        "min_tau": config.min_tau,
        "tau_decay": config.tau_decay,
        "lambda_cql": config.lambda_cql,
    }
    logger.save_config(run_config)

    env = gym.make(env_id)

    # NOTE(security): pickle.load executes arbitrary code embedded in the
    # file - only point dataset_path at files you trust.
    with open(dataset_path, 'rb') as f:
        raw_data = pickle.load(f)

    states_tf = tf.constant(raw_data['observations'], dtype=tf.float32)
    actions_tf = tf.constant(raw_data['actions'], dtype=tf.float32)
    rewards_tf = tf.constant(raw_data['rewards'].reshape(-1, 1), dtype=tf.float32)
    next_states_tf = tf.constant(raw_data['next_observations'], dtype=tf.float32)
    dones_tf = tf.constant(raw_data['terminals'].reshape(-1, 1), dtype=tf.float32)

    if int(states_tf.shape[0]) == 0:
        # An empty pipeline yields no batches, so the while-loop below would
        # spin forever without ever advancing steps_so_far.
        raise ValueError(f"Offline dataset at {dataset_path!r} contains no transitions")

    batches = tf.data.Dataset.from_tensor_slices((
        states_tf, actions_tf, rewards_tf, next_states_tf, dones_tf
    )).shuffle(buffer_size=int(1e6)).batch(config.batch_size).prefetch(tf.data.AUTOTUNE)

    # Honour the caller's agent; only build one when none was supplied.
    if agent is None:
        agent = ReshuffleSACAgent(env, config)

    if log_wandb:
        wandb.init(
            project="Offline-SAC-Reshuffle",
            name=f"{env_id}-seed{seed}",
            config=dict(run_config),
            tags=["SAC", "Offline", "Reshuffle"]
        )

    steps_so_far = 0
    while steps_so_far < max_steps:
        # sync old networks at start of each epoch
        agent.sync_old_networks()

        for s_batch, a_batch, r_batch, sn_batch, d_batch in batches:
            if steps_so_far >= max_steps:
                break

            # No behaviour-density model is fitted here: every sample gets
            # the constant value 1.
            p_d_tf = tf.ones([s_batch.shape[0], 1], dtype=tf.float32)

            metrics = agent.train_step(s_batch, a_batch, r_batch, sn_batch, d_batch, p_d_tf)
            steps_so_far += 1

            if log_wandb:
                wandb.log({
                    "step": steps_so_far,
                    "loss/q_rs": metrics['loss_q_rs'],
                    "loss/q_d": metrics['loss_q_d'],
                    "loss/n1": metrics['loss_n1'],
                    "loss/n2": metrics['loss_n2'],
                    "metrics/kl_div": metrics['kl_div'],
                    "loss/actor": metrics['actor_loss'],
                    "loss/alpha": metrics['alpha_loss'],
                    "metrics/alpha_value": metrics['alpha_val'],
                    "metrics/bellman_error": metrics['bellman_err'],
                    "metrics/bellman_error2": metrics['bellman_err2'],
                    "metrics/value_error": metrics['value_err'],
                    "metrics/weight_mean": metrics['weight_mean'],
                    "metrics/tau": metrics['tau_used'],
                })

            # Log training metrics
            if steps_so_far % 1000 == 0:
                logger.log(metrics, steps_so_far)
                print(f"Step {steps_so_far} / {max_steps} => "
                      f"Q_RS: {metrics['loss_q_rs']:.4f}, "
                      f"Q_D: {metrics['loss_q_d']:.4f}, "
                      f"N1: {metrics['loss_n1']:.4f}, "
                      f"N2: {metrics['loss_n2']:.4f}, "
                      f"Alpha: {metrics['alpha_val']:.4f}, "
                      f"Actor: {metrics['actor_loss']:.4f}")

            if (steps_so_far % config.eval_interval) == 0:
                eval_return_rs, eval_return_d = evaluate_both_policies(agent, env)
                logger.log_eval({
                    'return_rs': eval_return_rs,
                    'return_d': eval_return_d
                }, steps_so_far)
                print(f"[Eval] Step={steps_so_far}, RS Return={eval_return_rs:.2f}, D Return={eval_return_d:.2f}")

    return agent

# ----------------------------------------------------
# 5. Evaluation
# ----------------------------------------------------
def evaluate_offline_policy(agent: "ReshuffleSACAgent", env, eval_episodes=5):
    """Return the mean undiscounted return of the agent's deterministic policy.

    Uses the gymnasium API: ``env.reset()`` returns ``(obs, info)`` and
    ``env.step()`` returns ``(obs, reward, terminated, truncated, info)``.
    An episode ends as soon as it is either terminated or truncated.

    Args:
        agent: object exposing ``act(obs, deterministic=True)``.
        env: gymnasium-style environment.
        eval_episodes: number of episodes to average over.

    Returns:
        ``np.mean`` of the per-episode returns.
    """
    returns = []
    for _ in range(eval_episodes):
        obs, _ = env.reset()
        done = False
        ep_ret = 0
        while not done:
            act = agent.act(obs, deterministic=True)
            obs, rew, terminated, truncated, _ = env.step(act)
            ep_ret += rew
            # Time-limit truncation must end the loop too, otherwise
            # environments that never terminate would roll out forever.
            done = terminated or truncated
        returns.append(ep_ret)
    return np.mean(returns)
def evaluate_both_policies(agent, env, eval_episodes=1):
    """Return the mean episode return of the RS policy and of the D policy.

    For each of ``eval_episodes`` rounds, one episode is rolled out with
    ``agent.act`` and then one with ``agent.act_d``, both deterministic and
    on the same environment. Uses the gymnasium API: ``env.reset()`` returns
    ``(obs, info)`` and ``env.step()`` returns
    ``(obs, reward, terminated, truncated, info)``.

    Args:
        agent: object exposing ``act`` and ``act_d``, each callable as
            ``fn(obs, deterministic=True)``.
        env: gymnasium-style environment.
        eval_episodes: number of episodes per policy.

    Returns:
        Tuple ``(mean_return_rs, mean_return_d)``.
    """
    def _episode_return(select_action):
        # Roll out one full episode and sum its rewards.
        obs, _ = env.reset()
        total = 0
        done = False
        while not done:
            act = select_action(obs, deterministic=True)
            obs, rew, terminated, truncated, _ = env.step(act)
            total += rew
            # Stop on time-limit truncation as well as on termination.
            done = terminated or truncated
        return total

    returns_rs = []
    returns_d = []
    for _ in range(eval_episodes):
        returns_rs.append(_episode_return(agent.act))
        returns_d.append(_episode_return(agent.act_d))

    return np.mean(returns_rs), np.mean(returns_d)

# ----------------------------------------------------
# 6. Main
# ----------------------------------------------------
cfg = Config()
# Possibly override config if desired
cfg.max_gradient_steps = 1000000

# Single source of truth for the values that were previously repeated below.
main_env_id = "HalfCheetah-v4"
main_seed = 42

# Seed before the agent is constructed so that anything drawn during its
# construction comes after seeding.
# (presumably set_seeds seeds the numpy / TF RNGs - verify against its definition)
set_seeds(main_seed)

env = gym.make(main_env_id)
try:
    agent = ReshuffleSACAgent(env, cfg)

    agent = train_offline_sac_reshuffle(
        agent=agent,
        env_id=main_env_id,
        seed=main_seed,
        max_steps=cfg.max_gradient_steps,
        config=cfg,
        log_wandb=False
    )
    ret = evaluate_offline_policy(agent, env)
    print(f"Final deterministic policy evaluation average return: {ret:.2f}")
finally:
    # Release the simulator/renderer resources even if training fails.
    env.close()
