In [None]:
import random
import time

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

from myenv import MyEnv
from distributions import Distribution

import toml
from types import SimpleNamespace

# Load run hyperparameters from config.toml. Each top-level TOML key becomes an
# attribute of `args` (e.g. args.seed, args.batch_size) read throughout the notebook.
config = toml.load("config.toml")
args = SimpleNamespace(**config)

In [None]:
# Define RBF kernel function
def rbf_kernel(X1, X2, sigma=1.0):
    """Gaussian (RBF) kernel matrix between two sets of row vectors.

    K[i, j] = exp(-||X1[i] - X2[j]||^2 / (2 * sigma^2))

    Args:
        X1: array of shape (n, d).
        X2: array of shape (m, d).
        sigma: kernel bandwidth.

    Returns:
        Array of shape (n, m) with entries in [0, 1].
    """
    sq_norms_1 = jnp.sum(X1**2, axis=1).reshape(-1, 1)
    sq_norms_2 = jnp.sum(X2**2, axis=1)
    pairwise_sq_dists = sq_norms_1 + sq_norms_2 - 2 * jnp.dot(X1, X2.T)
    # The expansion ||a||^2 + ||b||^2 - 2 a.b can round to a small negative value
    # for (near-)identical rows; clamp so the kernel never exceeds 1.
    pairwise_sq_dists = jnp.maximum(pairwise_sq_dists, 0.0)
    return jnp.exp(-pairwise_sq_dists / (2 * sigma**2))

# Define Kernel Ridge Regression class
class KernelRidgeRegression:
    """Kernel ridge regression with an RBF kernel, used here as the Q-function critic.

    fit() solves (K + lambda_val * I) alpha = y on the training inputs and
    predict() returns K(X, X_train) @ alpha. Before the first fit() the model has
    no data and predicts 0 for every input (the KRR prediction with an empty
    training set), so it can be queried for bootstrap targets before any fit.

    Note: predict() reads mutable instance state (X_train, alpha). Do not wrap the
    *bound* method in jax.jit: the trace would freeze the values present at the
    first call and silently ignore later fit() calls with same-shaped data.
    """

    def __init__(self, lambda_val=0.1, sigma=1.0):
        self.lambda_val = lambda_val  # ridge (L2) regularization added to the kernel diagonal
        self.sigma = sigma            # RBF bandwidth
        self.alpha = None             # dual coefficients, set by fit()
        self.X_train = None           # training inputs, set by fit()

    def fit(self, X, y):
        """Fit on inputs X of shape (n, d) and targets y whose leading dimension is n."""
        K = rbf_kernel(X, X, self.sigma)
        self.alpha = jnp.linalg.solve(K + self.lambda_val * jnp.eye(X.shape[0]), y)
        self.X_train = X

    def predict(self, X):
        """Predict targets for inputs X of shape (m, d).

        Returns an array with leading dimension m; zeros of shape (m,) when the
        model has not been fitted yet.
        """
        if self.X_train is None or self.alpha is None:
            return jnp.zeros((X.shape[0],))
        K = rbf_kernel(X, self.X_train, self.sigma)
        return jnp.dot(K, self.alpha)


In [None]:
# Initialize Actor
class Actor:
    """Hand-parameterized policy mapping one observation to a 2x2 covariance matrix.

    ``params`` is a flat vector of 13 entries:
      * params[0:3]  - linear model for cos(phi), phi being the rotation angle
      * params[3:8]  - alpha = p3^2 + p4^2 (o0 - p5)^2 + p6^2 (o1 - p7)^2
      * params[8:13] - beta  = p8^2 + p9^2 (o0 - p10)^2 + p11^2 (o1 - p12)^2

    The result is R(phi) @ diag(alpha, beta) @ R(phi).T, a symmetric matrix whose
    eigenvalues alpha and beta are non-negative (sums of squares).
    """

    # Keep the arccos argument strictly inside (-1, 1): outside that range arccos
    # is NaN, and at exactly +/-1 its derivative is infinite (NaN gradients).
    _COS_CLIP = 1.0 - 1e-6

    def apply(self, params, observations):
        """Return the 2x2 covariance for a single observation.

        Only observations[0] and observations[1] are read; batching is done
        externally (e.g. with jax.vmap over the leading axis).
        """
        cos_phi = params[0] + params[1] * observations[0] + params[2] * observations[1]
        phi = jnp.arccos(jnp.clip(cos_phi, -self._COS_CLIP, self._COS_CLIP))
        alpha = params[3]**2 + params[4]**2 * (observations[0] - params[5])**2 + params[6]**2 * (observations[1] - params[7])**2
        beta = params[8]**2 + params[9]**2 * (observations[0] - params[10])**2 + params[11]**2 * (observations[1] - params[12])**2

        # Rotation by phi and the diagonal eigenvalue matrix.
        t1 = jnp.array([[jnp.cos(phi), -jnp.sin(phi)], [jnp.sin(phi), jnp.cos(phi)]])
        t2 = jnp.array([[alpha, 0], [0, beta]])

        sigma2 = t1 @ t2 @ t1.T

        return sigma2

In [None]:
# Flax TrainState variant that also carries a target copy of the actor parameters
# (read as `actor_state.target_params` and Polyak-updated during training).
class ActorTrainState(TrainState):
    # The actor's parameters are a flat jnp array rather than a nested FrozenDict.
    params: jnp.ndarray
    target_params: jnp.ndarray

# NOTE(review): this redefinition shadows the `TrainState` imported from
# flax.training.train_state. ActorTrainState above subclasses the imported class
# only because it is defined first; re-running this cell would make
# ActorTrainState subclass this redefinition instead. Consider renaming it.
# It is not used elsewhere in the visible notebook.
class TrainState(TrainState):
    params: flax.core.FrozenDict
    target_params: flax.core.FrozenDict

In [None]:
# Record the hyperparameters
# Run directory name: <env_id>__<exp_name>__<seed>__<unix time>.
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{run_name}")
# Render every config entry as a row of a two-column markdown table (TensorBoard text tab).
hparam_table_rows = "\n".join(f"|{key}|{value}|" for key, value in vars(args).items())
writer.add_text("hyperparameters", "|param|value|\n|-|-|\n" + hparam_table_rows)

In [None]:
# Random seed
# Seed Python's `random`, NumPy's global RNG and JAX from the same config value.
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
# NOTE(review): actor_key and qf1_key are not consumed anywhere visible; the actor
# parameters are hand-initialized and the critic is fitted in closed form.
key, actor_key, qf1_key = jax.random.split(key, 3)

In [None]:
# Setup env
# `log_p` is presumably the target log-density handed to MyEnv, `dim` its
# dimensionality and `max_steps` the episode length — confirm against MyEnv.
log_p = Distribution.gaussian
dim = 2
max_steps=100

env = MyEnv(log_p, dim, max_steps)
# NOTE(review): max_action is not used anywhere in the visible notebook.
max_action = float(env.action_space.high[0])
# Force a float32 observation dtype; presumably so the replay buffer allocates
# float32 storage for observations — confirm MyEnv's native dtype.
env.observation_space.dtype = np.float32
# Replay buffer kept on CPU with timeout handling disabled.
rb = ReplayBuffer(
    args.buffer_size,
    env.observation_space,
    env.action_space,
    device='cpu',
    handle_timeout_termination=False,
)

In [None]:
# Start
# Initial observation of the first episode.
obs, _ = env.reset()

# Policy (observation -> 2x2 covariance) and kernel-ridge Q-function critic.
actor = Actor()
qf1 = KernelRidgeRegression(lambda_val=0.1, sigma=1.0)

In [None]:
# Online and target actor start from the same hand-picked 13-entry parameter
# vector (at observation (0, 0) it yields the identity covariance). JAX arrays are
# immutable, so sharing one array between params and target_params is safe.
initial_actor_params = jnp.array(
    [0.0, 0.0, 0.0, 1.0, 2.5, 0.0, 2.5, 0.0, 1.0, 2.5, 0.0, 2.5, 0.0]
)
actor_optimizer = optax.adam(learning_rate=args.learning_rate)
actor_state = ActorTrainState.create(
    apply_fn=actor.apply,
    params=initial_actor_params,
    target_params=initial_actor_params,
    tx=actor_optimizer,
)

In [None]:
# Batch the single-observation actor over the leading axis of `observations`
# (params shared across the batch) and JIT-compile it. Wrapping the class-level
# `Actor.apply` instead of the current `actor.apply` attribute keeps this cell
# idempotent: re-running it does not stack a second vmap onto a batched function.
actor.apply = jax.jit(
    jax.vmap(
        lambda params, observations: Actor.apply(actor, params, observations),
        in_axes=(None, 0),
    )
)
# qf1.predict is deliberately NOT jitted: jax.jit traces the bound method once per
# input shape and bakes the current qf1.X_train / qf1.alpha into the compiled
# function as constants. Every qf1.fit() uses a same-sized batch, so the cached
# trace would keep predicting with stale values (or with the unfitted state).

In [None]:
def update_critic(
    observations: np.ndarray,
    actions: np.ndarray,
    next_observations: np.ndarray,
    rewards: np.ndarray,
    dones: np.ndarray,
):
    """Refit the kernel-ridge critic `qf1` on one batch of one-step TD targets.

    Target: y = r + (1 - done) * gamma * Q(s', clip(pi_target(s'), -1, 1)),
    where pi_target is the actor evaluated with `actor_state.target_params`.

    Args:
        observations: batch of observations, leading dimension B.
        actions: batch of flattened actions, leading dimension B.
        next_observations: batch of next observations, leading dimension B.
        rewards: rewards, shape (B,).
        dones: done flags (0/1), shape (B,).

    Returns:
        (qf1_loss_value, qf1_a_values): in-sample MSE of the refitted critic
        against the targets, and its predictions Q(s, a) for the batch.
    """
    # Prepare the data
    X = jnp.concatenate([observations, actions], axis=-1)
    batch_size = next_observations.shape[0]
    # The batched actor returns one 2x2 matrix per sample, shape (B, 2, 2); flatten
    # each row-major, like the flattened actions stored in the replay buffer,
    # so it can be concatenated with the 2-D observations.
    # NOTE(review): target actions are clipped to [-1, 1] but rollout actions are not.
    next_state_actions = (
        actor.apply(actor_state.target_params, next_observations)
        .clip(-1, 1)
        .reshape(batch_size, -1)
    )
    qf1_next_target = qf1.predict(jnp.concatenate([next_observations, next_state_actions], axis=-1))
    next_q_value = (rewards + (1 - dones) * args.gamma * (qf1_next_target)).reshape(-1)

    # Train Q function using KRR (closed-form refit on this batch only)
    qf1.fit(X, next_q_value)

    # Compute loss (for logging only; the fit above is not gradient-based)
    qf1_a_values = qf1.predict(X)
    qf1_loss_value = jnp.mean((qf1_a_values - next_q_value) ** 2)

    return qf1_loss_value, qf1_a_values


In [None]:
# Initialize the optimizer state.
# NOTE(review): `opt_state` is not used in the visible notebook; `actor_state`
# carries its own optimizer state (actor_state.opt_state), which is what
# apply_gradients() updates.
opt_state = actor_state.tx.init(actor_state.params)


def update_actor(actor_state, observations, qf1):
    """One gradient step on the actor, followed by a Polyak update of its target.

    Loss: -mean_b Q(s_b, flatten(pi(s_b))), differentiated with respect to the
    actor parameters only; the critic `qf1` is held fixed.

    Args:
        actor_state: ActorTrainState with params, target_params and optimizer state.
        observations: batch of observations, leading dimension B.
        qf1: critic exposing `predict(X)`.

    Returns:
        (new_actor_state, actor_loss_value).
    """

    def actor_loss(params, observations, qf1):
        actions = actor.apply(params, observations)
        # Flatten each per-sample action matrix; use the batch's own size rather
        # than assuming it equals args.batch_size.
        flat_actions = actions.reshape(observations.shape[0], -1)
        return -qf1.predict(jnp.concatenate([observations, flat_actions], axis=-1)).mean()

    actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params, observations, qf1)
    actor_state = actor_state.apply_gradients(grads=grads)
    # target <- tau * params + (1 - tau) * target
    actor_state = actor_state.replace(
        target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau)
    )
    return actor_state, actor_loss_value

In [None]:
start_time = time.time()
# Defined before the loop so the logging block never raises NameError when the
# actor has not been updated yet (policy_frequency need not divide 100).
actor_loss_value = None
for global_step in range(args.total_timesteps):
    # Action: before learning_starts act with the identity covariance; afterwards
    # query the current actor for the single current observation.
    if global_step < args.learning_starts:
        actions = np.eye(env.dim).flatten()
        policy_cov_func = lambda obs: jnp.eye(env.dim)
    else:
        policy_cov_func = lambda obs: jnp.squeeze(actor.apply(actor_state.params, obs.reshape(1, -1)))
        actions = policy_cov_func(obs).flatten()

    # Execute the env and log data.
    actions_matrix = actions.reshape(env.dim, -1)
    next_obs, rewards, terminateds, truncateds, infos = env.step(actions_matrix, policy_cov_func)

    # Record rewards for plotting purposes
    if "final_info" in infos:
        for info in infos["final_info"]:
            print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
            writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
            writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
            break

    # Save data to replay buffer; handle `terminal_observation`
    real_next_obs = next_obs.copy()
    rb.add(obs, real_next_obs, actions, rewards, terminateds, infos)

    # Update observation, starting a new episode once the current one has ended.
    # NOTE(review): assumes MyEnv follows the Gymnasium contract of requiring an
    # explicit reset() after termination/truncation — confirm it does not auto-reset.
    obs = next_obs
    if np.any(terminateds) or np.any(truncateds):
        obs, _ = env.reset()

    # Training.
    if global_step > args.learning_starts:
        data = rb.sample(args.batch_size)
        qf1_loss_value, qf1_a_values = update_critic(
            data.observations.numpy(),
            data.actions.numpy(),
            data.next_observations.numpy(),
            data.rewards.flatten().numpy(),
            data.dones.flatten().numpy(),
        )
        if global_step % args.policy_frequency == 0:
            # update_actor returns (new_actor_state, loss); the critic is refit in place.
            actor_state, actor_loss_value = update_actor(
                actor_state,
                data.observations.numpy(),
                qf1,
            )

        if global_step % 100 == 0:
            writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step)
            if actor_loss_value is not None:
                writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step)
            # qf1_a_values holds one prediction per sampled transition; log the batch mean.
            writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
            steps_per_second = int(global_step / (time.time() - start_time))
            print("SPS:", steps_per_second)
            writer.add_scalar("charts/SPS", steps_per_second, global_step)

In [None]:
# Inspect the trained policy's covariance for the latest observation; the
# observation is passed as a batch of one row, so the output keeps a leading axis of size 1.
actor.apply(actor_state.params, obs.reshape(1, -1))