In [1]:
!pip install gymnasium==0.29.1 --quiet
!pip install numpy==1.26.4 --quiet
!pip install tensorflow==2.17.0 --quiet
!pip install ml-dtypes>=0.3.1 --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m37.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m601.3/601.3 MB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import gymnasium as gym
import os, pickle, math

In [3]:
class ObGD(keras.optimizers.Optimizer):
    def __init__(self, learning_rate=1.0, gamma=0.99, lamda=0.8, kappa=2.0, name='ObGD'):
        """Initialize the optimizer."""
        super().__init__(name=name, learning_rate=learning_rate)
        self.gamma = gamma
        self.lamda = lamda
        self.kappa = kappa
        self.state = {}
        self.param_groups = [{
            'learning_rate': learning_rate,
            'gamma': gamma,
            'lamda': lamda,
            'kappa': kappa
        }]

    def apply_gradients(self, grads_and_vars, delta, reset=False, name=None, **kwargs):
        """Mirror of PyTorch's step method"""
        grads_and_vars = [(g, v) for (g, v) in grads_and_vars if g is not None]

        def get_var_id(var):
            return f"{var.name}_{var.shape}"

        # First pass: Initialize eligibility traces
        for grad, var in grads_and_vars:
            var_id = get_var_id(var)
            if var_id not in self.state:
                self.state[var_id] = {
                    "eligibility_trace": tf.zeros(var.shape, dtype=var.dtype)
                }

        # Second pass: Update eligibility traces and compute z_sum
        z_sum = 0.0
        for group in self.param_groups:
            for grad, var in grads_and_vars:
                var_id = get_var_id(var)
                e = self.state[var_id]["eligibility_trace"]

                # Update eligibility trace
                decay = group['gamma'] * group['lamda']
                new_e = tf.multiply(e, decay)
                new_e = tf.add(new_e, grad)

                self.state[var_id]["eligibility_trace"] = new_e
                z_sum += tf.reduce_sum(tf.abs(new_e))

        # Compute step size
        delta_bar = tf.maximum(tf.abs(delta), 1.0)
        dot_product = delta_bar * z_sum * group['learning_rate'] * group['kappa']
        step_size = tf.where(dot_product > 1,
                           group['learning_rate'] / dot_product,
                           group['learning_rate'])

        # Third pass: Update parameters
        for grad, var in grads_and_vars:
            var_id = get_var_id(var)
            e = self.state[var_id]["eligibility_trace"]
            var.assign_sub(step_size * delta * e)

            if reset:
                self.state[var_id]["eligibility_trace"] = tf.zeros(var.shape, dtype=var.dtype)

        return True

    def get_config(self):
        config = super().get_config()
        config.update({
            'gamma': self.gamma,
            'lamda': self.lamda,
            'kappa': self.kappa,
        })
        return config

In [4]:
def sparse_init(shape, sparsity, type='uniform'):
    if len(shape) == 2:
        fan_out, fan_in = shape
        num_zeros = int(math.ceil(sparsity * fan_in))

        if type == 'uniform':
            tensor = tf.random.uniform(shape, -math.sqrt(1.0/fan_in), math.sqrt(1.0/fan_in))
        elif type == 'normal':
            tensor = tf.random.normal(shape, 0, math.sqrt(1.0/fan_in))
        else:
            raise ValueError("Unknown initialization type")

        mask = tf.ones(shape)
        for col_idx in range(fan_out):
            # Convert to int32
            col_idx = tf.cast(col_idx, tf.int32)
            zero_indices = tf.cast(tf.random.shuffle(tf.range(fan_in))[:num_zeros], tf.int32)
            updates = tf.zeros(num_zeros)
            indices = tf.stack([
                tf.repeat(col_idx, num_zeros),
                zero_indices
            ], axis=1)
            mask = tf.tensor_scatter_nd_update(mask, indices, updates)
        return tensor * mask

    elif len(shape) == 4:
        channels_out, channels_in, h, w = shape
        fan_in = channels_in * h * w
        num_zeros = int(math.ceil(sparsity * fan_in))

        if type == 'uniform':
            tensor = tf.random.uniform(shape, -math.sqrt(1.0/fan_in), math.sqrt(1.0/fan_in))
        elif type == 'normal':
            tensor = tf.random.normal(shape, 0, math.sqrt(1.0/fan_in))
        else:
            raise ValueError("Unknown initialization type")

        mask = tf.ones(shape)
        for out_channel_idx in range(channels_out):
            # Convert to int32
            out_channel_idx = tf.cast(out_channel_idx, tf.int32)
            zero_indices = tf.cast(tf.random.shuffle(tf.range(fan_in))[:num_zeros], tf.int32)
            updates = tf.zeros(num_zeros)
            flat_mask = tf.reshape(mask[out_channel_idx], [-1])
            flat_mask = tf.tensor_scatter_nd_update(
                flat_mask,
                tf.expand_dims(zero_indices, 1),
                updates
            )
            mask = tf.tensor_scatter_nd_update(
                mask,
                tf.expand_dims(out_channel_idx, 0),
                [tf.reshape(flat_mask, [channels_in, h, w])]
            )
        return tensor * mask
    else:
        raise ValueError("Only tensors with 2 or 4 dimensions are supported")

In [5]:
class SampleMeanStd:
    def __init__(self, shape=()):
        self.mean = np.zeros(shape, "float64")
        self.var = np.ones(shape, "float64")
        self.p = np.ones(shape, "float64")
        self.count = 0

    def update(self, x):
        if self.count == 0:
            self.mean = x
            self.p = np.zeros_like(x)
        self.mean, self.var, self.p, self.count = self.update_mean_var_count_from_moments(self.mean, self.p, self.count, x*1.0)

    def update_mean_var_count_from_moments(self, mean, p, count, sample):
        new_count = count + 1
        new_mean = mean + (sample - mean) / new_count
        p = p + (sample - mean) * (sample - new_mean)
        new_var = 1 if new_count < 2 else p / (new_count - 1)
        return new_mean, new_var, p, new_count

class NormalizeObservation(gym.Wrapper):
    def __init__(self, env: gym.Env, epsilon: float = 1e-8):
        super().__init__(env)
        try:
            self.num_envs = self.get_wrapper_attr("num_envs")
            self.is_vector_env = self.get_wrapper_attr("is_vector_env")
        except AttributeError:
            self.num_envs = 1
            self.is_vector_env = False

        if self.is_vector_env:
            self.obs_stats = SampleMeanStd(shape=self.single_observation_space.shape)
        else:
            self.obs_stats = SampleMeanStd(shape=self.observation_space.shape)
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, terminateds, truncateds, infos = self.env.step(action)
        if self.is_vector_env:
            obs = self.normalize(obs)
        else:
            obs = self.normalize(np.array([obs]))[0]
        return obs, rews, terminateds, truncateds, infos

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        if self.is_vector_env:
            return self.normalize(obs), info
        else:
            return self.normalize(np.array([obs]))[0], info

    def normalize(self, obs):
        self.obs_stats.update(obs)
        return (obs - self.obs_stats.mean) / np.sqrt(self.obs_stats.var + self.epsilon)

class ScaleReward(gym.Wrapper):
    def __init__(self, env: gym.Env, gamma: float = 0.99, epsilon: float = 1e-8):
        super().__init__(env)
        try:
            self.num_envs = self.get_wrapper_attr("num_envs")
            self.is_vector_env = self.get_wrapper_attr("is_vector_env")
        except AttributeError:
            self.num_envs = 1
            self.is_vector_env = False
        self.reward_stats = SampleMeanStd(shape=())
        self.reward_trace = np.zeros(self.num_envs)
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, terminateds, truncateds, infos = self.env.step(action)
        if not self.is_vector_env:
            rews = np.array([rews])
        term = terminateds or truncateds
        self.reward_trace = self.reward_trace * self.gamma * (1 - term) + rews
        rews = self.normalize(rews)
        if not self.is_vector_env:
            rews = rews[0]
        return obs, rews, terminateds, truncateds, infos

    def normalize(self, rews):
        self.reward_stats.update(self.reward_trace)
        return rews / np.sqrt(self.reward_stats.var + self.epsilon)

In [6]:
class AddTimeInfo(gym.Wrapper):
    def __init__(self, env: gym.Env):
        super().__init__(env)
        if self.env.num_envs > 1:
            raise ValueError("AddTimeInfo only supports single environments")
        self.epi_time = -0.5
        if 'dm_control' in env.spec.id:
            self.time_limit = 1000
        else:
            self.time_limit = env.spec.max_episode_steps
        self.obs_space_size = self.observation_space.shape[0] + self.env.num_envs
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.obs_space_size,), dtype=np.float32)
        if not (isinstance(self.action_space, gym.spaces.Box) or isinstance(self.action_space, gym.spaces.Discrete)):
            raise ValueError("Unsupported action space")

    def step(self, action):
        obs, rews, terminateds, truncateds, infos = self.env.step(action)
        obs = np.concatenate((obs, np.array([self.epi_time] * self.env.num_envs)))
        self.epi_time += 1.0 / self.time_limit
        return obs, rews, terminateds, truncateds, infos

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        self.epi_time = -0.5
        obs = np.concatenate((obs, np.array([self.epi_time])))
        return obs, info

In [None]:
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)

class StreamQ(keras.Model):
    def __init__(self, n_obs=11, n_actions=3, hidden_size=32, lr=1.0, epsilon_target=0.01,
                 epsilon_start=1.0, exploration_fraction=0.1, total_steps=1_000_000,
                 gamma=0.99, lamda=0.8, kappa_value=2.0):
        super(StreamQ, self).__init__()
        self.n_actions = n_actions
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_target = epsilon_target
        self.epsilon = epsilon_start
        self.exploration_fraction = exploration_fraction
        self.total_steps = total_steps
        self.time_step = 0

        # Define layers
        self.fc1_v = keras.layers.Dense(
            hidden_size,
            kernel_initializer=lambda shape, dtype: sparse_init(shape, 0.9),
            bias_initializer='zeros'
        )
        self.hidden_v = keras.layers.Dense(
            hidden_size,
            kernel_initializer=lambda shape, dtype: sparse_init(shape, 0.9),
            bias_initializer='zeros'
        )
        self.fc_v = keras.layers.Dense(
            n_actions,
            kernel_initializer=lambda shape, dtype: sparse_init(shape, 0.9),
            bias_initializer='zeros'
        )
        self.layer_norm1 = keras.layers.LayerNormalization()
        self.layer_norm2 = keras.layers.LayerNormalization()

        self.optimizer = ObGD(learning_rate=lr, gamma=gamma, lamda=lamda, kappa=kappa_value)

    def call(self, x):
        if isinstance(x, np.ndarray):
            x = tf.convert_to_tensor(x, dtype=tf.float32)
        # Ensure input has batch dimension
        if len(x.shape) == 1:
            x = tf.expand_dims(x, 0)
        x = self.fc1_v(x)
        x = self.layer_norm1(x)
        x = tf.nn.leaky_relu(x)
        x = self.hidden_v(x)
        x = self.layer_norm2(x)
        x = tf.nn.leaky_relu(x)
        return self.fc_v(x)

    def sample_action(self, s):
        self.time_step += 1
        self.epsilon = linear_schedule(
            self.epsilon_start,
            self.epsilon_target,
            self.exploration_fraction * self.total_steps,
            self.time_step
        )
        if np.random.rand() < self.epsilon:
            q_values = self.call(s)
            greedy_action = tf.argmax(q_values, axis=-1).numpy()[0]
            random_action = np.random.randint(0, self.n_actions)
            if greedy_action == random_action:
                return random_action, False
            else:
                return random_action, True
        else:
            q_values = self.call(s)
            return tf.argmax(q_values, axis=-1).numpy()[0], False

    def update_params(self, s, a, r, s_prime, done, is_nongreedy, overshooting_info=False):
        done_mask = 0.0 if done else 1.0
        s = tf.convert_to_tensor(s, dtype=tf.float32)
        s_prime = tf.convert_to_tensor(s_prime, dtype=tf.float32)
        r = tf.convert_to_tensor(r, dtype=tf.float32)
        done_mask = tf.convert_to_tensor(done_mask, dtype=tf.float32)

        # Ensure inputs have batch dimension
        if len(s.shape) == 1:
            s = tf.expand_dims(s, 0)
        if len(s_prime.shape) == 1:
            s_prime = tf.expand_dims(s_prime, 0)

        with tf.GradientTape() as tape:
            q_sa = self.call(s)[0][a]
            max_q_s_prime_a_prime = tf.reduce_max(self.call(s_prime)[0])
            td_target = r + self.gamma * max_q_s_prime_a_prime * done_mask
            delta = td_target - q_sa
            loss = -q_sa

        grads = tape.gradient(loss, self.trainable_variables)
        grads_and_vars = list(zip(grads, self.trainable_variables))

        # Log gradient stats occasionally (every 1000 steps)
        # if self.time_step % 1000 == 0:
        #     grad_norms = []
        #     for grad, var in grads_and_vars:
        #         if grad is not None:
        #             grad_norms.append(tf.norm(grad).numpy())
        #     if grad_norms:
        #         print(f"Step {self.time_step}, Gradient Stats - Mean: {np.mean(grad_norms):.6f}, Std: {np.std(grad_norms):.6f}, Max: {np.max(grad_norms):.6f}, Min: {np.min(grad_norms):.6f}")

        self.optimizer.apply_gradients(grads_and_vars, delta.numpy(), reset=(done or is_nongreedy))

        if overshooting_info:
            max_q_s_prime_a_prime = tf.reduce_max(self.call(s_prime)[0])
            td_target = r + self.gamma * max_q_s_prime_a_prime * done_mask
            delta_bar = td_target - self.call(s)[0][a]
            if tf.sign(delta_bar * delta).numpy() == -1:
                print("Overshooting Detected!")

In [8]:
def agent_env_interaction(env_name, seed, lr, gamma, lamda, total_steps, epsilon_target,
                         epsilon_start, exploration_fraction, kappa_value, debug,
                         overshooting_info, render=False):
    tf.random.set_seed(seed)
    np.random.seed(seed)

    env = gym.make(env_name, render_mode='human', max_episode_steps=10_000) if render else gym.make(env_name, max_episode_steps=10_000)
    env = gym.wrappers.FlattenObservation(env)
    env = gym.wrappers.RecordEpisodeStatistics(env)
    env = ScaleReward(env, gamma=gamma)
    env = NormalizeObservation(env)
    env = AddTimeInfo(env)

    agent = StreamQ(
        n_obs=env.observation_space.shape[0],
        n_actions=env.action_space.n,
        lr=lr,
        gamma=gamma,
        lamda=lamda,
        epsilon_target=epsilon_target,
        epsilon_start=epsilon_start,
        exploration_fraction=exploration_fraction,
        total_steps=total_steps,
        kappa_value=kappa_value
    )

    if debug:
        print("seed: {}".format(seed), "env: {}".format(env.spec.id))

    returns, term_time_steps = [], []
    s, _ = env.reset(seed=seed)
    episode_num = 1

    for t in range(1, total_steps+1):
        a, is_nongreedy = agent.sample_action(s)
        s_prime, r, terminated, truncated, info = env.step(a)
        agent.update_params(s, a, r, s_prime, terminated or truncated, is_nongreedy, overshooting_info)
        s = s_prime

        if terminated or truncated:
            if debug:
                print("Episodic Return: {}, Time Step {}, Episode Number {}, Epsilon {}".format(
                    info['episode']['r'][0], t, episode_num, agent.epsilon))
            returns.append(info['episode']['r'][0])
            term_time_steps.append(t)
            terminated, truncated = False, False
            s, _ = env.reset()
            episode_num += 1

    env.close()
    save_dir = "data_stream_q_{}_lr{}_gamma{}_lamda{}".format(env.spec.id, lr, gamma, lamda)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    with open(os.path.join(save_dir, "seed_{}.pkl".format(seed)), "wb") as f:
        pickle.dump((returns, term_time_steps, env_name), f)

In [9]:
agent_env_interaction(
    env_name='CartPole-v1',
    seed=0,
    lr=1.0,
    gamma=0.99,
    lamda=0.8,
    total_steps=100_000,
    epsilon_target=0.01,
    epsilon_start=1.0,
    exploration_fraction=0.05,
    kappa_value=2.0,
    debug=True,
    overshooting_info=False
)

seed: 0 env: CartPole-v1
Episodic Return: 18.0, Time Step 18, Episode Number 1, Epsilon 0.996436
Episodic Return: 46.0, Time Step 64, Episode Number 2, Epsilon 0.987328
Episodic Return: 27.0, Time Step 91, Episode Number 3, Epsilon 0.981982
Episodic Return: 12.0, Time Step 103, Episode Number 4, Epsilon 0.979606
Episodic Return: 10.0, Time Step 113, Episode Number 5, Epsilon 0.977626
Episodic Return: 43.0, Time Step 156, Episode Number 6, Epsilon 0.969112
Episodic Return: 32.0, Time Step 188, Episode Number 7, Epsilon 0.962776
Episodic Return: 18.0, Time Step 206, Episode Number 8, Epsilon 0.959212
Episodic Return: 49.0, Time Step 255, Episode Number 9, Epsilon 0.94951
Episodic Return: 26.0, Time Step 281, Episode Number 10, Epsilon 0.944362
Episodic Return: 30.0, Time Step 311, Episode Number 11, Epsilon 0.938422
Episodic Return: 26.0, Time Step 337, Episode Number 12, Epsilon 0.933274
Episodic Return: 19.0, Time Step 356, Episode Number 13, Epsilon 0.929512
Episodic Return: 16.0, Tim

KeyboardInterrupt: 