In [None]:
import random
from functools import partial

import jax
from jax.scipy.stats import norm
import jax.numpy as jnp
import numpy as np

from rlax import dpg_loss, td_learning

from myenv import MyEnv
from base_rl_mcmc.actor_critics import PolicyFunction, QFunction
from base_rl_mcmc.replay_buffer import ReplayBuffer

import toml
from types import SimpleNamespace

# Run configuration: every key of config.toml becomes an attribute of `args`
# (this notebook reads args.seed, args.gamma and args.learning_rate).
config = toml.load("./base_rl_mcmc/config.toml")
args = SimpleNamespace(**config)

# Seed Python's, NumPy's and JAX's random number generators from the same value,
# then split the JAX root key into `key`, `actor_key` and `qfOLS_key`.
random.seed(args.seed)
np.random.seed(args.seed)
key, actor_key, qfOLS_key = jax.random.split(jax.random.PRNGKey(args.seed), 3)

In [None]:
class QFunctionOLS:
    """
    Hand-written state-action value function
    Q(s, a; w) = w * a**2 / (1 + (a - |s|)**2),
    with a least-squares (Monte Carlo return) fit and a TD(0) weight update.
    """

    def __init__(self):
        pass

    def __call__(self, state, action, weights):
        """
        Forwards pass of the State Action Function (Q-function).

        Evaluated elementwise with jax.numpy, so the inputs broadcast against each other.
        """
        return weights * action**2 / (1 + (action - jnp.abs(state))**2)

    def grad_weights(self, state, action, weights):
        """
        Backwards pass of the Q-function with respect to the weights.

        Returns the forward-mode Jacobian dQ/dweights.
        """
        return jax.jacfwd(self.__call__, argnums=2)(state, action, weights)

    def grad_action(self, state, action, weights):
        """
        Backwards pass of the Q-function with respect to the action.

        Returns the forward-mode Jacobian dQ/daction.
        """
        return jax.jacfwd(self.__call__, argnums=1)(state, action, weights)

    @staticmethod
    def _step_rewards(state_list):
        """
        Per-step rewards r_t = sum((s_{t+1} - s_t)**2) for states stacked along axis 0.

        Returns a 1-D array with one entry fewer than there are states.
        """
        squared_steps = jnp.diff(jnp.asarray(state_list), axis=0) ** 2
        if squared_steps.ndim > 1:
            # Collapse per-state component axes so each step yields one scalar reward.
            squared_steps = jnp.sum(squared_steps, axis=tuple(range(1, squared_steps.ndim)))
        return squared_steps

    def _returns_to_go(self, state_list, gamma):
        """
        Discounted return from every state onward: G_t = r_t + gamma * G_{t+1}.

        The final state has no outgoing reward, so its return is 0 and the result has
        one entry per state (one regression target per visited state).
        """
        rewards = np.asarray(self._step_rewards(state_list), dtype=float)
        returns = np.zeros(rewards.shape[0] + 1)
        # Backward recursion avoids the O(n^2) cost of re-summing every suffix.
        for t in range(rewards.shape[0] - 1, -1, -1):
            returns[t] = rewards[t] + gamma * returns[t + 1]
        return returns

    def cumulative_return(self, state_list, gamma=args.gamma):
        """
        Compute the discounted cumulative return of a trajectory:
        sum_t gamma**t * (s_{t+1} - s_t)**2, summed over state components.

        jnp.sum accepts only array-like inputs, not a Python list of arrays, so the
        per-step terms are built as one array.
        """
        rewards = self._step_rewards(state_list)
        discounts = gamma ** jnp.arange(rewards.shape[0])
        return jnp.sum(discounts * rewards)

    def update_weights_least_square(self, state_list, action_list, gamma=args.gamma):
        """
        Fit the weights by ordinary least squares on one trajectory.

        Regresses the discounted return-to-go G_t of every visited state onto the
        design matrix M = [s_t, a_t] (one row per step), i.e. solves
        min_w ||M w - G||^2.

        Args:
            state_list: visited states. They are flattened to a column, so this
                assumes 1-D states.
            action_list: actions taken, one per state (same length as state_list).
            gamma: discount factor used for the returns.

        Returns:
            The least-squares solution w (one coefficient per column of M).
        """
        states = jnp.asarray(state_list).reshape(-1, 1)
        actions = jnp.asarray(action_list).reshape(-1, 1)
        M = jnp.concatenate((states, actions), axis=1)
        # One target per row of M; a single scalar return cannot be regressed on M.
        G = self._returns_to_go(states, gamma)

        return jnp.linalg.lstsq(M, G, rcond=None)[0]

    def update_weights_TD(self, state, action, next_state, next_action, weights, learning_rate=args.learning_rate, gamma=args.gamma):
        """
        Update the weights of the Q-function with one semi-gradient TD(0) step:
        w <- w + lr * delta * dQ(s, a)/dw, where
        delta = r + gamma * Q(s', a') - Q(s, a) and r = ||s' - s||^2.

        NOTE(review): rlax.td_learning expects rank-0 inputs, so this assumes scalar
        Q-values. Confirm against the shapes the caller passes.
        """
        v_tm1 = self.__call__(state, action, weights)
        v_t = self.__call__(next_state, next_action, weights)
        r_t = jnp.power(jnp.linalg.norm(next_state - state, 2), 2)
        td_error = td_learning(v_tm1=v_tm1, r_t=r_t, discount_t=gamma, v_t=v_t, stop_target_gradients=True)

        # The TD error is a scalar signal, not a weight update: scale it by dQ/dweights.
        return weights + learning_rate * td_error * self.grad_weights(state, action, weights)

In [None]:
class PolicyFunctionOLS(QFunctionOLS):
    """
    Deterministic policy mu(s; theta) = theta[0]**2 + theta[1]**2 * |s|.

    NOTE(review): this class subclasses QFunctionOLS but overrides ``__call__`` with a
    two-argument (state, theta) signature. The inherited Q-function methods
    (grad_weights, grad_action, update_weights_*) therefore do not work on policy
    instances. Q-function gradients below use a separate QFunctionOLS instance.
    """

    def __init__(self):
        pass

    def __call__(self, state, theta):
        """
        Forwards pass of the Policy Function.
        """
        return theta[0]**2 + theta[1]**2 * jnp.abs(state)

    def grad_theta(self, state, theta):
        """
        Backwards pass of the Policy Function with respect to theta
        (forward-mode Jacobian dmu/dtheta).
        """
        return jax.jacfwd(self.__call__, argnums=1)(state, theta)

    def update_theta_Q(self, state, action, weights, theta, learning_rate=args.learning_rate):
        """
        One deterministic-policy-gradient step on theta.

        dQ/da is evaluated at `action` (normally mu(state; theta)) with critic weights
        `weights`. rlax.dpg_loss turns it into a loss on mu(state; theta) whose
        gradient with respect to theta is -dmu/dtheta * dQ/da. Gradient descent on
        that loss is therefore gradient ascent on Q.

        Returns:
            The updated theta.
        """
        q_function = QFunctionOLS()
        # QFunctionOLS is elementwise in the action, so the gradient of the summed
        # Q-values gives the per-component dQ/da with the same shape as `action`.
        dqda_t = jax.grad(lambda a: jnp.sum(q_function(state, a, weights)))(action)
        # rlax.dpg_loss works on rank-1 action / dQda arrays.
        dqda_t = jnp.atleast_1d(dqda_t)

        def loss_fn(th):
            # The action must depend on theta; otherwise the policy gets no gradient.
            a_t = jnp.atleast_1d(self.__call__(state, th))
            return jnp.sum(dpg_loss(a_t=a_t, dqda_t=dqda_t, dqda_clipping=None, use_stop_gradient=True))

        return theta - learning_rate * jax.grad(loss_fn)(theta)


In [None]:
def train():
    """
    Stub: currently performs no work and returns None.

    TODO: implement training.
    """
    return None

In [None]:
# Setup env: a 1-dimensional target whose log-density is the standard normal
# (loc=0, scale=1), with an episode cap of max_steps.
dim = 1
max_steps = 10_000
log_p = partial(norm.logpdf, loc=0, scale=1)

env = MyEnv(log_p, dim, max_steps)
# Force float32 observations for the downstream JAX code.
env.observation_space.dtype = np.float32
# Upper bound of the first action component.
max_action = float(env.action_space.high[0])

In [None]:
# Start: reset the environment and build the hand-written actor and critic.
# The second value from env.reset() is bound to a named variable instead of `_`:
# assigning `_` in IPython stops it from updating its last-output variable `_`
# for the rest of the session.
obs, reset_info = env.reset()

actor = PolicyFunctionOLS()
critics = QFunctionOLS()