In [None]:
import random
import time

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from torch.utils.tensorboard import SummaryWriter

import plotly.graph_objects as go

from myenv import MyEnv
from distributions import Distribution

import toml
from types import SimpleNamespace

# Load run hyperparameters; every top-level key of config.toml becomes an
# attribute on `args` (e.g. args.seed, args.gamma, args.learning_rate).
CONFIG_PATH = "config.toml"
config = toml.load(CONFIG_PATH)
args = SimpleNamespace(**config)

In [None]:
class QFunction:
    """Critic that is linear in hand-built features: Q(x, theta) = phi(x, theta) @ w.

    Features per row are [1, x, theta, vec(outer(x, theta))]. The weights are
    refit from scratch by least squares on every call to `fit`.
    """

    def __init__(self):
        # 1-D weight vector; stays None until the first call to `fit`.
        self.weights = None

    def _basis_functions(self, x, theta):
        """Build the feature matrix for a batch of (x, theta) rows.

        Args:
            x: (N, dx) array.
            theta: (N, dt) array.

        Returns:
            (N, 1 + dx + dt + dx*dt) array: bias column, x, theta, and the
            flattened per-row outer product of x and theta.
        """
        x_theta = jnp.einsum('ij,ik->ijk', x, theta).reshape(x.shape[0], -1)
        return jnp.hstack([jnp.ones((x.shape[0], 1)), x, theta, x_theta])

    def fit(self, X, theta, y):
        """Set `self.weights` to the minimum-norm least-squares solution.

        Args:
            X: (N, dx) array of observations.
            theta: (N, dt) array of policy parameters.
            y: (N,) or (N, 1) array of regression targets.
        """
        Phi = self._basis_functions(X, theta)
        y = jnp.asarray(y)
        # Accept column-vector targets but always store 1-D weights, so
        # `apply` returns (N,) both before and after fitting.
        if y.ndim == 2 and y.shape[1] == 1:
            y = y[:, 0]
        # pinv(Phi) @ y equals pinv(Phi.T @ Phi) @ Phi.T @ y but avoids
        # squaring the condition number of Phi.
        self.weights = jnp.linalg.pinv(Phi) @ y

    def apply(self, x, theta):
        """Evaluate Q for a batch of rows; returns an (N,) array.

        Before the first `fit` this returns zeros. It deliberately does not
        assign `self.weights` here: `apply` may run under jax tracing, and
        storing a traced value on the instance would leak a tracer.
        """
        phi = self._basis_functions(x, theta)
        if self.weights is None:
            return jnp.zeros(phi.shape[0], dtype=phi.dtype)
        return phi @ self.weights


In [None]:
# Actor (policy) definition: maps an observation to a 2x2 matrix.
class Actor:
    """Parametric policy returning R(phi) @ diag(alpha, beta) @ R(phi).T.

    `params` is a flat vector; entries 0-12 are read:
      * cos(phi) = params[0] + params[1]*o0 + params[2]*o1   (clipped, see below)
      * alpha    = params[3]**2 + params[4]**2*(o0 - params[5])**2
                   + params[6]**2*(o1 - params[7])**2
      * beta     = same form as alpha using params[8..12]
    where o0, o1 = observations[0], observations[1].

    Because alpha and beta are sums of squares they are non-negative, so the
    returned matrix is symmetric positive semi-definite.
    """

    # arccos is only defined on [-1, 1] and its derivative diverges at +/-1,
    # so its argument is clipped to this slightly narrower interval.
    _ARCCOS_EPS = 1e-6

    def apply(self, params, observations):
        """Return the (2, 2) matrix for a single observation.

        Args:
            params: flat array with at least 13 entries (layout in class doc).
            observations: indexable with [0] and [1].

        Returns:
            (2, 2) jnp array.
        """
        o0, o1 = observations[0], observations[1]

        # Without the clip, any params pushing the linear term outside [-1, 1]
        # made phi (and the whole matrix, and its gradient) NaN. Values more
        # than _ARCCOS_EPS inside the interval are unchanged.
        cos_arg = params[0] + params[1] * o0 + params[2] * o1
        cos_arg = jnp.clip(cos_arg, -1.0 + self._ARCCOS_EPS, 1.0 - self._ARCCOS_EPS)
        phi = jnp.arccos(cos_arg)

        alpha = params[3]**2 + params[4]**2 * (o0 - params[5])**2 + params[6]**2 * (o1 - params[7])**2
        beta = params[8]**2 + params[9]**2 * (o0 - params[10])**2 + params[11]**2 * (o1 - params[12])**2

        rotation = jnp.array([[jnp.cos(phi), -jnp.sin(phi)], [jnp.sin(phi), jnp.cos(phi)]])
        scales = jnp.array([[alpha, 0], [0, beta]])

        return rotation @ scales @ rotation.T

In [None]:
class ActorTrainState(TrainState):
    """flax TrainState for the actor that additionally carries `target_params`.

    In this notebook `target_params` is soft-updated towards `params` with
    `optax.incremental_update(..., args.tau)` after each actor gradient step.
    """
    params: jnp.ndarray  # re-annotated: the actor's parameters are a flat jnp array here
    target_params: jnp.ndarray  # soft-updated copy of `params`

In [None]:
# Record the hyperparameters as a markdown table in TensorBoard, under a run
# directory named from env id, experiment name, seed and a Unix timestamp.
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{run_name}")

hparam_rows = "\n".join(f"|{name}|{val}|" for name, val in vars(args).items())
writer.add_text("hyperparameters", "|param|value|\n|-|-|\n" + hparam_rows)

In [None]:
# Seed every RNG used in this notebook: Python's `random`, NumPy's global RNG,
# and a JAX PRNG key split into three.
# NOTE(review): `actor_key` and `qf1_key` are not used by any visible cell.
random.seed(args.seed)
np.random.seed(args.seed)
key, actor_key, qf1_key = jax.random.split(jax.random.PRNGKey(args.seed), 3)

In [None]:
# Setup env: MyEnv is built from `Distribution.gaussian` (passed as `log_p`,
# presumably the target log-density — TODO confirm), a 2-D space, and
# max_steps=100 (presumably the episode length cap — confirm in MyEnv).
dim = 2
max_steps = 100
log_p = Distribution.gaussian

env = MyEnv(log_p, dim, max_steps)
max_action = float(env.action_space.high[0])
# NOTE(review): overriding the space dtype by attribute assignment bypasses
# the space's constructor; confirm MyEnv produces float32 observations.
env.observation_space.dtype = np.float32

In [None]:
# Start: build the policy and the critic, then draw the first observation.
actor = Actor()
qf1 = QFunction()
obs, _ = env.reset()

In [None]:
# Online and target actor parameters start from the same 13-entry vector
# (layout documented on `Actor`). jax arrays are immutable, so sharing one
# array between both fields is safe.
initial_actor_params = jnp.array(
    [0.0, 0.0, 0.0, 1.0, 2.5, 0.0, 2.5, 0.0, 1.0, 2.5, 0.0, 2.5, 0.0]
)

actor_state = ActorTrainState.create(
    apply_fn=actor.apply,
    params=initial_actor_params,
    target_params=initial_actor_params,
    tx=optax.adam(learning_rate=args.learning_rate),
)

In [None]:
# JIT-compile the stateless actor.
# `qf1.apply` is intentionally NOT jitted: it reads the mutable `qf1.weights`,
# and jax.jit captures values read from the instance as constants at trace
# time. Later `qf1.fit` calls would then be ignored by the compiled function,
# so the critic would keep returning the first-traced (zero) values and the
# actor gradient would always be zero.
actor.apply = jax.jit(actor.apply)

In [None]:
def update_critic(
    actor_state: TrainState,
    observations: jnp.ndarray,
    rewards: jnp.ndarray,
    q_function: QFunction,
    next_observations=None,
    ):
    """One TD(0) step for the linear critic Q(obs, theta), theta = actor params.

    Args:
        actor_state: state whose `params` are the policy parameters theta.
        observations: observation the reward was earned from; flattened to one row.
        rewards: reward for that transition (presumably a scalar — confirm
            against the env).
        q_function: critic, refit in place via `q_function.fit`.
        next_observations: observation after the transition, used for the
            bootstrap term. If None, falls back to `observations`, which
            reproduces the previous behaviour. With that fallback the TD
            error degenerates to r + (gamma - 1) * Q(s).

    Returns:
        (loss, q_current): mean squared TD error before the refit, and the
        critic's estimate Q(observations, theta) before the refit.

    NOTE(review): episode termination is not masked out of the bootstrap term.
    """
    theta = actor_state.params.reshape(1, -1)
    obs_row = observations.reshape(1, -1)

    # Current Q value estimate.
    q_current = q_function.apply(obs_row, theta)

    # TD target: bootstrap from the next state when it is provided.
    if next_observations is None:
        q_next = q_current
    else:
        q_next = q_function.apply(next_observations.reshape(1, -1), theta)
    q_target = rewards + args.gamma * q_next

    td_error = q_target - q_current
    qf1_loss_value = jnp.mean(td_error**2)

    # Regress Q(observations, theta) onto the TD target.
    q_function.fit(obs_row, theta, q_target.reshape(1, -1))

    return qf1_loss_value, q_current

In [None]:
# Initial state of the optimizer.
# NOTE(review): `opt_state` is not read by any visible cell. flax's
# `TrainState.create` already initializes `actor_state.opt_state`, and that is
# the state `actor_state.apply_gradients` uses.
opt_state = actor_state.tx.init(actor_state.params)

def update_actor(
    actor_state: TrainState,
    observations: jnp.ndarray,
    q_function: QFunction
    ):
    """Gradient step that increases the critic's value of the current params.

    The loss is -mean(Q(observations, params)). After the optimizer step,
    `target_params` is soft-updated towards the new params with rate args.tau.

    Returns:
        (new_actor_state, loss_before_step)
    """
    obs_row = observations.reshape(1, -1)

    def negated_value(candidate_params):
        # Only the policy parameters are differentiated; the observation row
        # and the critic are closed over as constants.
        return -q_function.apply(obs_row, candidate_params.reshape(1, -1)).mean()

    loss, grads = jax.value_and_grad(negated_value)(actor_state.params)
    stepped = actor_state.apply_gradients(grads=grads)
    soft_target = optax.incremental_update(stepped.params, stepped.target_params, args.tau)

    return stepped.replace(target_params=soft_target), loss

In [None]:
start_time = time.time()
# Defined up front so the logging below never hits a NameError when the first
# logging step precedes the first actor update (policy_frequency not dividing 100).
actor_loss_value = None

for global_step in range(args.total_timesteps):
    # Action: identity covariance during warm-up, learned covariance afterwards.
    # The lambdas read the global `actor_state` when called (late binding).
    if global_step < args.learning_starts:
        actions = np.eye(env.dim).flatten()
        policy_cov_func = lambda obs: jnp.eye(env.dim)
    else:
        policy_cov_func = lambda obs: jnp.squeeze(actor.apply(actor_state.params, obs))
        actions = policy_cov_func(obs).flatten()

    # Execute the env and log data.
    actions_matrix = actions.reshape(env.dim, -1)
    next_obs, rewards, terminateds, truncateds, infos = env.step(actions_matrix, policy_cov_func)
    # NOTE(review): `terminateds`/`truncateds` are unused and env.reset() is
    # never called after the first episode; confirm MyEnv handles episode
    # boundaries itself.

    # Record rewards for plotting purposes
    if "final_info" in infos:
        for info in infos["final_info"]:
            print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
            writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
            writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
            break

    # Training. Runs before `obs` is advanced so each reward is paired with the
    # observation it was earned from (previously it was paired with next_obs).
    if global_step > args.learning_starts:
        qf1_loss_value, qf1_a_values = update_critic(actor_state, obs, rewards, qf1)
        if global_step % args.policy_frequency == 0:
            actor_state, actor_loss_value = update_actor(
                actor_state,
                obs,
                qf1,
            )

        if global_step % 100 == 0:
            writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step)
            if actor_loss_value is not None:
                writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step)
            writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step)
            sps = int(global_step / (time.time() - start_time))
            print("SPS:", sps)
            writer.add_scalar("charts/SPS", sps, global_step)

    # Update observation
    obs = next_obs

In [None]:
# Plot Policy: evaluate trace(policy covariance) on a square grid of
# observations (obs[0] along x, obs[1] along y).
GRID_LIMIT = 5.0     # grid spans [-GRID_LIMIT, GRID_LIMIT] on both axes
GRID_POINTS = 1000   # samples per axis

x0 = np.linspace(-GRID_LIMIT, GRID_LIMIT, GRID_POINTS)
x1 = x0.copy()
x, y = np.meshgrid(x0, x1)

# One vectorized call over all GRID_POINTS**2 grid points instead of one
# Python-level policy call per point. Row-major ravel/reshape keeps
# z[i, j] aligned with (x[i, j], y[i, j]).
grid_obs = jnp.asarray(np.stack([x.ravel(), y.ravel()], axis=1))
grid_covs = jax.vmap(policy_cov_func)(grid_obs)
z = np.asarray(jnp.trace(grid_covs, axis1=-2, axis2=-1)).reshape(GRID_POINTS, GRID_POINTS)

In [None]:
# Interactive 3-D surface of trace(policy covariance) over the observation grid.
fig = go.Figure(data=[go.Surface(z=z, x=x, y=y, colorscale='Viridis')])
fig.update_layout(
    title='Trace of the policy covariance over observation space',
    scene=dict(xaxis_title='x',
               yaxis_title='y',
               zaxis_title='Trace'),
)

fig.show()