# Dueling DQN with AutoEncoder

In [None]:
%env TF_FORCE_GPU_ALLOW_GROWTH=true
%env WANDB_AGENT_MAX_INITIAL_FAILURES=1024

import wandb
import gymnasium
import numpy as np
import tensorflow as tf
from sklearn.manifold import TSNE

# Turn on XLA JIT compilation through TensorFlow's global optimizer config.
# NOTE(review): presumably this only affects graph-compiled code (e.g. the
# Keras train_on_batch steps), not eager calls such as action sampling — confirm.
tf.config.optimizer.set_jit(True)  # Enable XLA.

## Hyperparameters

In [None]:
wandb.login()

# Every hyperparameter is pinned to a single fixed value; wandb expects each
# one wrapped as {"value": ...} inside the sweep's "parameters" section.
_fixed_hyperparameters = {
    "epochs": 2000,
    "buffer_size": 1000000,
    "batch_size": 256,
    "lr": 3e-4,
    "global_clipnorm": 1.0,
    "tau": 0.01,
    "gamma": 0.99,
    "temp_init": 1.0,
    "temp_min": 0.01,
    "temp_decay": 1e-5,
}

sweep_config = {
    "method": "random",
    # "score" is the per-episode extrinsic return logged by `run`.
    "metric": {"goal": "maximize", "name": "score"},
    "parameters": {
        param_name: {"value": param_value}
        for param_name, param_value in _fixed_hyperparameters.items()
    },
}

sweep_id = wandb.sweep(sweep_config, project="Dueling-DQN-with-AutoEncoder")

## Models

In [None]:
class DuelingDQN(tf.keras.Model):
    """Dueling Q-network: a shared MLP trunk feeding a value and an advantage head.

    Q-values are assembled as ``V(s) + (A(s, a) - mean_a A(s, a))``.
    """

    def __init__(self, action_space):
        super().__init__()

        def dense(units, gain, activation=None):
            # Fully-connected layer with an orthogonal kernel of the given gain.
            return tf.keras.layers.Dense(
                units,
                activation=activation,
                kernel_initializer=tf.keras.initializers.Orthogonal(gain),
            )

        trunk_gain = tf.sqrt(2.0)
        # Shared trunk.
        self.fc1 = dense(512, trunk_gain, activation="elu")
        self.fc2 = dense(256, trunk_gain, activation="elu")
        # Linear heads; the small 0.01 gain keeps their initial outputs near zero.
        self.V = dense(1, 0.01)
        self.A = dense(action_space, 0.01)

    def call(self, inputs, training=None):
        """Return Q-values, one per action along the last axis."""
        hidden = inputs
        for layer in (self.fc1, self.fc2):
            hidden = layer(hidden, training=training)
        value = self.V(hidden, training=training)
        advantage = self.A(hidden, training=training)
        # Centre the advantages so that V and A are identifiable.
        centred_advantage = advantage - tf.reduce_mean(
            advantage, axis=-1, keepdims=True
        )
        return value + centred_advantage

    def get_action(self, state, temperature):
        """Sample an action for the first state in ``state`` from softmax(Q / temperature).

        Lower temperatures make the choice greedier.
        """
        scaled_q = self(state) / temperature
        samples = tf.random.categorical(scaled_q, 1)
        return samples[0, 0]

In [None]:
class IntrinsicModel(tf.keras.Model):
    """Dense autoencoder whose reconstruction error is used as an intrinsic reward.

    A Keras ``Normalization`` layer tracks running reward statistics so that raw
    reconstruction errors can be standardised and then clipped to [0, 6].
    """

    def __init__(self):
        super().__init__()

        # Running mean/variance of intrinsic rewards.
        self.rew_rms = tf.keras.layers.Normalization()

    def build(self, input_shape):
        feature_dim = input_shape[-1]
        gain = tf.sqrt(2.0)

        def dense(units, activation="elu", name=None):
            # Orthogonally-initialised fully-connected layer.
            return tf.keras.layers.Dense(
                units,
                activation=activation,
                kernel_initializer=tf.keras.initializers.Orthogonal(gain),
                name=name,
            )

        widths = (128, 64, 32, 16)
        # Encoder: funnel down to a linear latent code half the input width.
        self._encoder = [dense(width) for width in widths] + [
            dense(feature_dim // 2, activation=None, name="latent_space")
        ]
        # Decoder: mirror image back up to a linear reconstruction of the input.
        self._decoder = [dense(width) for width in reversed(widths)] + [
            dense(feature_dim, activation=None, name="reconstruction")
        ]
        self.rew_rms.build((None, 1))
        super().build(input_shape)

    def call(self, inputs, training=None, only_encoder=None):
        """Return the reconstruction, or just the latent code when ``only_encoder``."""
        x = inputs
        for layer in self._encoder:
            x = layer(x, training=training)
        if only_encoder:
            return x
        for layer in self._decoder:
            x = layer(x, training=training)
        return x

    def get_int_reward(self, inputs):
        """Return the squared reconstruction error summed over the last axis (kept as size 1)."""
        inputs = tf.cast(inputs, dtype=self.dtype)
        reconstruction = self(inputs, training=False)
        squared_error = tf.square(inputs - reconstruction)
        return tf.reduce_sum(squared_error, axis=-1, keepdims=True)

    def update_reward_norm(self, reward):
        """Fold a batch of raw rewards into the running statistics and refresh them."""
        self.rew_rms.update_state(reward)
        self.rew_rms.finalize_state()

    def normalize_reward(self, reward):
        """Standardise ``reward`` with the running statistics, then clip to [0, 6]."""
        standardised = self.rew_rms(reward)
        return tf.nn.relu6(standardised)

## Replay buffer

In [None]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of (state, action, ext_reward, next_state, done).

    Once ``size`` transitions have been stored, each new one overwrites the
    oldest slot.
    """

    def __init__(self, shape, size=1e6, seed=None):
        """Pre-allocate storage.

        Args:
            shape: Shape of a single observation.
            size: Capacity; converted with ``int`` so floats like ``1e6`` work.
            seed: Optional seed for the sampling RNG (``None`` = fresh entropy).
        """
        self.size = int(size)
        self.counter = 0  # total transitions ever added; not capped at `size`
        self.state_buffer = np.zeros((self.size, *shape), dtype=np.float32)
        self.action_buffer = np.zeros(self.size, dtype=np.int32)
        self.ext_reward_buffer = np.zeros(self.size, dtype=np.float32)
        self.new_state_buffer = np.zeros((self.size, *shape), dtype=np.float32)
        self.terminal_buffer = np.zeros(self.size, dtype=np.bool_)
        # Generator.choice(replace=False) does not permute the whole population
        # on every call, unlike the legacy np.random.choice, which matters
        # when sampling small batches from a buffer of ~1e6 transitions.
        self._rng = np.random.default_rng(seed)

    def __len__(self):
        """Number of transitions currently stored (never more than ``size``)."""
        return min(self.counter, self.size)

    def add(self, state, action, ext_reward, new_state, done):
        """Store one transition, overwriting the oldest once the buffer is full."""
        idx = self.counter % self.size
        self.state_buffer[idx] = state
        self.action_buffer[idx] = action
        self.ext_reward_buffer[idx] = ext_reward
        self.new_state_buffer[idx] = new_state
        self.terminal_buffer[idx] = done
        self.counter += 1

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, ext_rewards, next_states, dones)`` of arrays.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored transitions.
        """
        stored = len(self)
        if batch_size > stored:
            raise ValueError(
                f"cannot sample {batch_size} transitions from a buffer holding {stored}"
            )
        batch = self._rng.choice(stored, size=batch_size, replace=False)
        return (
            self.state_buffer[batch],
            self.action_buffer[batch],
            self.ext_reward_buffer[batch],
            self.new_state_buffer[batch],
            self.terminal_buffer[batch],
        )

## Learning

In [None]:
def update_target(net, net_targ, tau):
    """Polyak-average ``net``'s trainable variables into ``net_targ`` in place.

    Each target variable becomes ``tau * source + (1 - tau) * target``;
    ``tau=1`` is a hard copy.

    Raises:
        ValueError: if the two networks expose different numbers of trainable
            variables (e.g. one has not been built yet). ``zip`` would otherwise
            silently truncate to a partial or empty update.
    """
    sources = net.trainable_variables
    targets = net_targ.trainable_variables
    if len(sources) != len(targets):
        raise ValueError(
            f"cannot soft-update: source has {len(sources)} trainable variables, "
            f"target has {len(targets)}"
        )
    for source_weight, target_weight in zip(sources, targets):
        target_weight.assign(tau * source_weight + (1.0 - tau) * target_weight)


def train_step(
    dqn,
    target_dqn,
    int_model,
    replay_buffer,
    batch_size,
    tau,
    gamma,
    *,
    int_coef=1.0,
    ext_coef=0.0,
):
    """Run one update of the autoencoder and the Q-network, then soft-update the target.

    Args:
        dqn: Online Q-network, trained with ``train_on_batch``.
        target_dqn: Target Q-network that supplies the bootstrap values.
        int_model: Autoencoder whose reconstruction error is the intrinsic reward.
        replay_buffer: Source of the sampled transitions.
        batch_size: Number of transitions to sample.
        tau: Soft-update coefficient for the target network.
        gamma: Discount factor.
        int_coef: Weight of the normalised intrinsic reward in the TD target.
        ext_coef: Weight of the environment (extrinsic) reward in the TD target.
            The defaults (1.0, 0.0) train on intrinsic reward only. Pass
            ``int_coef=0.0, ext_coef=1.0`` for extrinsic-only experiments.

    Returns:
        ``(dqn_loss, int_loss)`` as returned by the two ``train_on_batch`` calls.
    """
    states, actions, ext_rewards, next_states, dones = replay_buffer.sample(batch_size)

    # Bootstrap value max_a Q_target(s', a) (vanilla DQN target).
    next_Q = tf.reduce_max(target_dqn(next_states), axis=-1)

    # Train the autoencoder on the current states, then score the next states
    # and normalise those scores with the updated running statistics.
    int_loss = int_model.train_on_batch(states, states)
    int_rewards = int_model.get_int_reward(next_states)
    int_model.update_reward_norm(int_rewards)
    int_rewards = int_model.normalize_reward(int_rewards)

    # Mix the reward signals. 1.0 * x == x exactly, and a zero-weight extrinsic
    # term is skipped, so the defaults reproduce the intrinsic-only target.
    rewards = int_coef * int_rewards[:, 0]
    if ext_coef:
        rewards = rewards + ext_coef * ext_rewards

    # TD targets: copy the current predictions and overwrite only the taken
    # action's entry, so the other actions keep their prediction as target.
    targets = np.array(dqn(states))
    not_done = 1.0 - tf.cast(dones, dtype=tf.float32)
    targets[np.arange(batch_size), actions] = rewards + not_done * gamma * next_Q

    dqn_loss = dqn.train_on_batch(states, targets)

    # Move the target network a step of size `tau` towards the online network.
    update_target(dqn, target_dqn, tau=tau)

    return dqn_loss, int_loss

## Run

In [None]:
def _make_optimizer(config):
    """Return a fresh Adam optimizer using the run's ``lr`` and ``global_clipnorm``."""
    return tf.keras.optimizers.Adam(
        learning_rate=config.lr,
        global_clipnorm=config.global_clipnorm,
    )


def _latest_indices(buffer, n):
    """Return the ring-buffer slots of the ``n`` most recently added transitions.

    Slots are taken modulo the buffer capacity, so the result stays correct
    after the buffer has wrapped around. They are ordered oldest first.
    """
    return np.arange(buffer.counter - n, buffer.counter) % buffer.size


def _latent_scatter_rows(code_normal, code_anomaly, latent_dim):
    """Build ``[x, y, label]`` rows for a 2-D scatter plot of encoder codes.

    Codes wider than 2 are projected with a single t-SNE fit over both groups.
    Codes exactly 2 wide are used as-is. 1-D codes are drawn on the line y = 1.
    """
    normal = np.asarray(code_normal)
    anomaly = np.asarray(code_anomaly)
    if latent_dim > 2:
        # One joint fit: separate fits would put the two groups in unrelated
        # coordinate systems, making the overlay meaningless.
        joint = TSNE(n_components=2).fit_transform(
            np.concatenate([normal, anomaly], axis=0)
        )
        normal, anomaly = joint[: len(normal)], joint[len(normal) :]
    rows = []
    for label, codes in (("Normal", normal), ("Winner", anomaly)):
        xs = codes[:, 0]
        ys = codes[:, 1] if latent_dim >= 2 else np.ones_like(xs)
        rows.extend([x, y, label] for x, y in zip(xs, ys))
    return rows


def run(config=None):
    """Train one W&B sweep run of the Dueling DQN driven by autoencoder rewards.

    Args:
        config: Forwarded to ``wandb.init``; hyperparameters are then read back
            from ``wandb.config``.
    """
    with wandb.init(config=config):
        config = wandb.config

        # Other environments tried: "CartPole-v1", "MountainCar-v0", "Acrobot-v1".
        env = gymnasium.make("LunarLander-v2")
        try:
            state_space = env.observation_space.shape[0]
            action_space = env.action_space.n

            total_steps = 0
            temp = config.temp_init

            q_model = DuelingDQN(action_space)
            q_model.compile(optimizer=_make_optimizer(config), loss="logcosh")
            target_q_model = DuelingDQN(action_space)
            # Subclassed Keras models create their weights on first call; before
            # that, get_weights() is empty and set_weights() copies nothing,
            # leaving the target with an unrelated random initialisation.
            dummy_state = np.zeros((1, state_space), dtype=np.float32)
            q_model(dummy_state)
            target_q_model(dummy_state)
            target_q_model.set_weights(q_model.get_weights())

            int_model = IntrinsicModel()
            int_model.compile(optimizer=_make_optimizer(config), loss="mse")

            exp_buffer = ReplayBuffer(
                shape=env.observation_space.shape, size=config.buffer_size
            )

            for epoch in range(config.epochs):
                state, _ = env.reset()
                done, truncated = False, False
                ep_ext_reward, ep_int_reward, ep_step = 0, 0, 0

                while not (done or truncated):
                    action = q_model.get_action(tf.expand_dims(state, axis=0), temp)
                    # np.asarray behaves the same on NumPy 1.x and 2.x, whereas
                    # np.array(..., copy=False) changed meaning in NumPy 2.
                    action = np.asarray(action, dtype=env.action_space.dtype)

                    next_state, ext_reward, done, truncated, _ = env.step(action)

                    # Intrinsic reward of the new state, used here for logging
                    # only (train_step recomputes it from the replay buffer).
                    int_reward = int_model.normalize_reward(
                        int_model.get_int_reward(tf.expand_dims(next_state, axis=0))
                    )

                    exp_buffer.add(state, action, ext_reward, next_state, done)

                    state = next_state
                    total_steps += 1
                    ep_step += 1
                    ep_ext_reward += ext_reward
                    ep_int_reward += float(int_reward[0, 0])

                    # Linear temperature decay, floored at temp_min.
                    temp = max(config.temp_min, temp - config.temp_decay)

                    if len(exp_buffer) >= config.batch_size:
                        dqn_loss, int_loss = train_step(
                            q_model,
                            target_q_model,
                            int_model,
                            exp_buffer,
                            config.batch_size,
                            config.tau,
                            config.gamma,
                        )
                        wandb.log(
                            {"dqn_loss": dqn_loss, "int_loss": int_loss},
                            commit=False,
                        )

                wandb.log(
                    {
                        "score": ep_ext_reward,
                        "int_score": ep_int_reward,
                        "steps": ep_step,
                        "temperature": temp,
                    },
                    commit=True,
                    step=epoch,
                )
                print(
                    f"Int reward - Mean: {int_model.rew_rms.mean}, Stddev: {tf.sqrt(int_model.rew_rms.variance)}, Count: {int_model.rew_rms.count}"
                )

                # Solved thresholds: env.spec.reward_threshold for CartPole-v1,
                # MountainCar-v0 and LunarLander-v2; use -92 for Acrobot-v1.
                if ep_ext_reward > env.spec.reward_threshold:
                    q_model.summary()
                    target_q_model.summary()
                    int_model.summary()

                    print(
                        f"\n{env.spec.id} is solved! {(epoch+1)} Episode in {total_steps} steps with reward {ep_ext_reward}"
                    )

                    # Compare encoder codes of randomly sampled buffered states
                    # ("Normal") with the winning episode's next states ("Winner").
                    sampled_states, _, _, _, _ = exp_buffer.sample(ep_step)
                    code_normal = int_model(sampled_states, only_encoder=True)
                    code_anomaly = int_model(
                        exp_buffer.new_state_buffer[_latest_indices(exp_buffer, ep_step)],
                        only_encoder=True,
                    )
                    data = _latent_scatter_rows(
                        code_normal, code_anomaly, state_space // 2
                    )

                    table = wandb.Table(data=data, columns=["LS 1", "LS 2", "Type"])
                    wandb.log(
                        {
                            "Latent space": wandb.plot.scatter(table, "LS 1", "LS 2"),
                        }
                    )

                    break
        finally:
            env.close()

In [None]:
# Start a sweep agent that calls `run` for up to 100 runs of `sweep_id`.
wandb.agent(sweep_id, function=run, count=100)