In [None]:
import random
from functools import partial

import jax
from jax.scipy.stats import norm
import jax.numpy as jnp
import numpy as np

from rlax import dpg_loss, td_learning, add_gaussian_noise

from myenv import MyEnv
from base_rl_mcmc.replay_buffer import ReplayBuffer

import toml
from types import SimpleNamespace

import matplotlib.pyplot as plt

from tqdm.auto import trange

# Hyper-parameters are read from the TOML file and exposed as attributes (args.seed, ...).
config = toml.load("./base_rl_mcmc/config.toml")
args = SimpleNamespace(**config)

# Random seed: Python, NumPy and JAX are all seeded with the configured value.
np.random.seed(args.seed)
random.seed(args.seed)
# Split the root JAX key three ways; the first part replaces the root key.
key, actor_key, critics_key = jax.random.split(jax.random.PRNGKey(args.seed), 3)

In [None]:
class QFunctionOLS:
    """
    Parametric State Action Function (Q-function).

    The model is ``Q(s, a; w) = w * a**2 / (1 + (a - |s|)**2)`` where ``w`` is
    the value stored in ``self.weights``.
    """

    def __init__(self, init_weights):
        self.weights = init_weights

    def __call__(self, state, action, weights):
        """
        Forwards pass of the State Action Function (Q-function).

        ``weights`` is passed explicitly so the same function can be
        differentiated with respect to it.
        """
        return weights * action**2 / (1 + (action - jnp.abs(state))**2)

    def forward(self, state, action):
        """
        Forwards pass using the currently stored weights.
        """
        return self(state, action, self.weights)

    def grad_weights(self, state, action):
        """
        Backwards pass of the Q-function with respect to the weights.

        Returns the forward-mode Jacobian, whose shape is the output shape
        followed by the shape of the weights.
        """
        return jax.jacfwd(self.__call__, argnums=2)(state, action, self.weights)

    def grad_action(self, state, action):
        """
        Backwards pass of the Q-function with respect to the action.
        """
        return jax.jacfwd(self.__call__, argnums=1)(state, action, self.weights)

    def cumulative_return(self, state_list, gamma=None):
        """
        Compute the cumulative return of a trajectory.

        The reward of transition ``i`` is the squared increment
        ``(state_list[i+1] - state_list[i])**2``; it is discounted by
        ``gamma**i`` and all terms are summed into a single scalar.
        A trajectory with fewer than two states yields 0.

        ``gamma`` defaults to ``args.gamma``, looked up when the method is called.
        """
        if gamma is None:
            gamma = args.gamma

        # jnp.sum does not accept a Python list, so work on one stacked array.
        states = jnp.asarray(state_list)
        squared_increments = (states[1:] - states[:-1]) ** 2
        # One discount factor per transition, shaped to broadcast over any
        # trailing state axes.
        discounts = gamma ** jnp.arange(squared_increments.shape[0])
        discounts = discounts.reshape((-1,) + (1,) * (squared_increments.ndim - 1))
        return jnp.sum(discounts * squared_increments)

    def update_weights_least_square(self, state_list, action_list):
        """
        Update the weights of the Q-function using least square.

        Builds a two-column design matrix ``[state, action]`` and returns the
        least-squares solution against the cumulative return. The result is
        returned; ``self.weights`` is not modified here.
        """
        state_list_dash = state_list.reshape(-1, 1)
        action_list_dash = action_list.reshape(-1, 1)
        M = np.concatenate((state_list_dash, action_list_dash), axis=1)
        # NOTE(review): cumulative_return produces one scalar for the whole
        # trajectory while M has one row per sample -- confirm the intended
        # regression target before relying on this method.
        G = self.cumulative_return(state_list_dash)

        return jnp.linalg.lstsq(M, G, rcond=None)[0]

    def update_weights_TD(self, state, action, next_state, next_action, alpha, omega, learning_rate=None, gamma=None):
        """
        Update the weights of the Q-function using TD(0) error.

        The bootstrapped target is
        ``omega * ((state - next_state)**2 * alpha + gamma * Q(next_state, next_action))``
        and the weights move along ``(target - Q(state, action)) * dQ/dw``.

        ``learning_rate`` and ``gamma`` default to ``args.learning_rate`` and
        ``args.gamma``, looked up when the method is called.
        """
        if learning_rate is None:
            learning_rate = args.learning_rate
        if gamma is None:
            gamma = args.gamma

        v_tm1 = self.forward(state, action).squeeze()
        v_t = (omega * ((state - next_state)**2 * alpha + gamma * self.forward(next_state, next_action))).squeeze()

        step = learning_rate * (v_t - v_tm1) * self.grad_weights(state, action)
        # The Jacobian carries extra leading axes (output shape + weights shape).
        # Adding it as-is would broadcast the weights to a higher rank on every
        # update, so bring the step back to the weights' own shape first.
        self.weights = self.weights + jnp.reshape(step, jnp.shape(self.weights))

In [None]:
class PolicyFunctionOLS:
    """
    Parametric deterministic Policy Function.

    The model is ``pi(s; theta) = theta[0]**2 + theta[1]**2 * |s|`` where
    ``theta`` is the value stored in ``self.theta``.
    """

    def __init__(self, init_theta):
        self.theta = init_theta

    def __call__(self, state, theta):
        """
        Forwards pass of the Policy Function.

        ``theta`` is passed explicitly so the same function can be
        differentiated with respect to it.
        """
        return theta[0]**2 + theta[1]**2 * jnp.abs(state)

    def forward(self, state):
        """
        Forwards pass using the currently stored theta.
        """
        return self(state, self.theta)

    def grad_theta(self, state):
        """
        Backwards pass of the Policy Function with respect to theta.

        Returns the forward-mode Jacobian, whose shape is the output shape
        followed by the shape of theta.
        """
        return jax.jacfwd(self.__call__, argnums=1)(state, self.theta)

    def update_theta_DPG(self, state, dqda_t, omega, learning_rate=None):
        """
        Deterministic policy gradient step on theta.

        Applies ``theta <- theta + learning_rate * omega * dpi/dtheta * dqda_t``.
        Only the gradient step is scaled by ``dqda_t``; the current theta is
        not, so ``dqda_t == 0`` leaves theta unchanged.

        ``learning_rate`` defaults to ``args.learning_rate``, looked up when
        the method is called.
        """
        if learning_rate is None:
            learning_rate = args.learning_rate

        # Flatten the Jacobian-shaped step before applying the critic's action
        # gradient, then bring it back to theta's own shape.
        ascent = (learning_rate * omega * self.grad_theta(state)).flatten() * dqda_t
        self.theta = self.theta + jnp.reshape(ascent, jnp.shape(self.theta))

In [None]:
# Setup env
# Log-density of a normal with loc=0, scale=1; passed to the environment as
# its first argument.
log_p = partial(norm.logpdf, loc=0, scale=1)

env = MyEnv(log_p, dim=1, max_steps=args.total_timesteps)
env.observation_space.dtype = np.float32
max_action = float(env.action_space.high[0])

# Replay buffer with room for one entry per environment step.
rb = ReplayBuffer(capacity=args.total_timesteps, state_dim=env.dim, action_dim=env.action_space.shape[0])

In [None]:
# Initialize Environment
# Starting parameters: one weight for the critic, two entries for the actor.
init_theta = jnp.array([1.0, 1.0])
init_weights = jnp.array([1.0])

actor = PolicyFunctionOLS(init_theta)
critics = QFunctionOLS(init_weights)

# First observation of the run; the second value returned by reset() is discarded.
obs, _ = env.reset()

In [None]:
# One environment step per iteration, followed by a critic update and a
# (possibly less frequent) actor update. Progress is reported by trange;
# per-step values are not printed so the cell output stays readable.
for global_step in trange(args.total_timesteps):
    # Action Process
    # Evaluate the policy once and reuse it: `action` feeds the updates below,
    # `sigma` is the column-vector view handed to the environment.
    action = actor.forward(obs)
    sigma = action.reshape(-1, 1)

    next_obs, reward, terminateds, truncateds, infos = env.step(
        action=sigma,
        policy_func=lambda x: actor.forward(x).reshape(-1, 1),
        # Multiplicative jitter: the policy output is scaled by a factor drawn
        # uniformly from [0.9, 1.1). The parentheses matter -- without them the
        # expression is 0.9 + (0.2 * u * policy), which is not a perturbation
        # of the policy output.
        noise_policy_func=lambda x: (0.9 + 0.2 * np.random.uniform()) * actor.forward(x),
    )
    # Copy of the new observation; it is not used further inside this loop.
    real_next_obs = next_obs.copy()

    # Training Session
    # Update Critics
    critics.update_weights_TD(
        state=obs,
        action=action,
        next_state=next_obs,
        next_action=actor.forward(next_obs),
        alpha=infos["alpha"],
        omega=infos["omega"],
    )

    # Update Actor
    if global_step % args.policy_frequency == 0:
        # Action-gradient of the (just updated) critic at the action taken.
        dqda_t = critics.grad_action(obs, action).flatten()
        actor.update_theta_DPG(
            state=obs,
            dqda_t=dqda_t,
            omega=infos["omega"],
        )

    # Swap observation
    obs = next_obs