In [None]:
from functools import partial
import jax
import os
import html

from datetime import datetime
from jax import numpy as jp
jax.config.update("jax_debug_nans", False)
jax.config.update("jax_enable_x64", True)
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output, display
from brax.io import html

import shutil

from src.pods.Pods import train as train_pods
from src.hds.Hds import train as train_hds
from src.envs.original import Pendulum 
from src.envs.realistic import RealisticPendulum 
from src.dyn_model.Predict import pretrained_params




In [None]:
# Report which backend JAX selected ('cpu', 'gpu', 'tpu', ...) and the devices
# it can see. `jax.default_backend()` replaces the deprecated (and, in recent
# JAX releases, removed) `jax.lib.xla_bridge.get_backend().platform`.
print(jax.default_backend())
print(jax.devices())

In [None]:
# Training environment. Reset it once under jit (seed 0) to get an initial state.
env = Pendulum.Pendulum()
initial_key = jax.random.PRNGKey(seed=0)
state = jax.jit(env.reset)(rng=initial_key)

# Uncomment to preview the initial pose (`html` here is brax.io.html):
#HTML(html.render(env.sys, [state.pipeline_state]))

In [None]:
# Start every notebook run from a clean slate: delete any checkpoint directory
# left behind by a previous run.
ckpt_dir = './tmp/flax_ckpt'
stale_checkpoints_exist = os.path.exists(ckpt_dir)
if stale_checkpoints_exist:
    shutil.rmtree(ckpt_dir)

In [None]:

def progress_f(x_data,y_data,epoch,reward):
    """Training-progress callback: record one (epoch, reward) point and redraw the curve.

    Args:
        x_data: List-like of epochs seen so far; `epoch` is appended in place.
        y_data: List-like of rewards seen so far; `reward` is appended in place.
        epoch: Current epoch index (x-axis value).
        reward: Total reward reported for this epoch (y-axis value).
    """
    x_data.append(epoch)
    y_data.append(reward)
    # Clear BEFORE printing. With wait=True the old output is removed as soon
    # as new output arrives. Printing first (as before) meant the freshly printed
    # reward was wiped together with the old plot and never visible.
    clear_output(wait=True)
    print(reward)
    plt.xlabel('Epoch')
    plt.ylabel('Total Reward')
    plt.plot(x_data, y_data)
    plt.show()

# Train the HDS policy on the training env. `progress_f` is passed as the
# progress callback so the reward curve is redrawn during training.
policy_hds = train_hds(
    env,
    trajectory_length=100,
    num_samples=100,
    epochs=20,
    inner_epochs=20,
    alpha_a=8e-5,
    aggregation_factor_beta=0.175,
    init_learning_rate=1e-3,
    init_noise=0.1,
    noise_decay=0.8,
    progress_fn=progress_f,
)

In [None]:
# Baseline: train a PODS policy on the same env. Every hyperparameter it shares
# with the HDS run (horizon, samples, epochs, alpha_a, learning rate) uses the
# same value.
policy_pods = train_pods(
    env,
    trajectory_length=100,
    num_samples=100,
    epochs=20,
    inner_epochs=20,
    alpha_a=8e-5,
    init_learning_rate=1e-3,
    progress_fn=progress_f,
)

In [None]:
# Evaluation setup: a second pendulum env (RealisticPendulum) used as the
# "real environment" in the comparison below. Each trained policy object is
# called with no arguments to obtain its inference function (observation -> action).
realistic_env = RealisticPendulum.RealisticPendulum()
inference_fn_hds = policy_hds()
inference_fn_pods = policy_pods()

@partial(jax.vmap, in_axes=(None, None, None, 0), axis_name="batch")
def rollout_policy(env, inference_fn, trajectory_length: int, prng_keys):
    """Roll out a policy for a fixed horizon, batched over PRNG keys.

    `jax.vmap` maps only over the leading axis of `prng_keys`. `env`,
    `inference_fn` and `trajectory_length` are shared by every rollout in the batch.

    Args:
        env: Environment exposing `reset(key)`, `step(state, action)`,
            `observation_size` and `action_size`. States must expose `.obs`
            and `.reward`.
        inference_fn: Maps one (un-batched) observation to an action.
        trajectory_length: Number of environment steps per rollout (Python int).
        prng_keys: Batch of PRNG keys, one per rollout. Each key is used to reset the env.

    Returns:
        `(observations, actions, total_reward)`. For each key:
        `observations` has shape `(trajectory_length, env.observation_size)` and
        holds the observation seen *before* each step; `actions` has shape
        `(trajectory_length, env.action_size)`; `total_reward` is the sum of the
        per-step rewards. The vmap prepends a batch axis to each.
    """

    def _advance(carry_state, _unused_key):
        # The per-step key is never consumed, so the rollout is a deterministic
        # function of the reset state.
        act = inference_fn(carry_state.obs)
        successor = env.step(carry_state, act)
        return successor, (carry_state.obs, act, successor.reward)

    first_state = env.reset(prng_keys)
    # NOTE(review): the reset key is reused to derive the step keys. This is
    # harmless only because `_advance` ignores them.
    step_keys = jax.random.split(prng_keys, trajectory_length)
    _, (obs_seq, act_seq, reward_seq) = jax.lax.scan(_advance, first_state, xs=step_keys)

    obs_seq = jp.reshape(obs_seq, (trajectory_length, env.observation_size))
    act_seq = jp.reshape(act_seq, (trajectory_length, env.action_size))
    return obs_seq, act_seq, jp.sum(reward_seq)

# Evaluate every (policy, environment) pair on the same 100 PRNG keys, so the
# per-key totals are directly comparable across pairs.
prng_key = jax.random.PRNGKey(seed=0)
subkeys = jax.random.split(prng_key, num=100)

def _total_rewards(eval_env, inference_fn):
    """Per-key total reward of `inference_fn` on `eval_env` over 100-step rollouts."""
    *_, per_key_totals = rollout_policy(eval_env, inference_fn, 100, subkeys)
    return per_key_totals

rewards_original_pods = _total_rewards(env, inference_fn_pods)
rewards_realistic = _total_rewards(realistic_env, inference_fn_pods)
rewards_hds = _total_rewards(realistic_env, inference_fn_hds)
rewards_hds_train_env = _total_rewards(env, inference_fn_hds)

# Scalar mean over the evaluation keys for each pair (same order as above).
(average_reward_original_pods, average_reward_realistic,
 average_reward_hds, average_reward_hds_train_env) = map(
    jp.mean,
    (rewards_original_pods, rewards_realistic, rewards_hds, rewards_hds_train_env),
)

In [None]:
# Each average_reward_* is already a scalar mean over the evaluation keys, so
# print it directly. Wrapping it in jp.array and averaging again was a no-op.
print(f'PODS on real environment {average_reward_realistic}')
print(f'PODS on train environment {average_reward_original_pods}')
print(f'HDS on real environment {average_reward_hds}')
print(f'HDS on train environment {average_reward_hds_train_env}')

In [None]:
# Per-key total reward for every (policy, environment) pair. The x-axis is the
# index of the evaluation PRNG key, so the curves are paired point-by-point.
# Derive the axis from the data instead of hard-coding 100, so the plot stays
# correct if the number of evaluation keys changes.
seeds = list(range(len(rewards_realistic)))
curves = [
    (rewards_realistic, 'PODS on Real Environment', 'blue'),
    (rewards_original_pods, 'PODS on Train Environment', 'red'),
    (rewards_hds, 'HDS on Real Environment (Ours)', 'green'),
    (rewards_hds_train_env, 'HDS on Train Environment (Ours)', 'orange'),
]
fig, ax = plt.subplots(figsize=(10, 6))
for rewards, label, color in curves:
    ax.plot(seeds, rewards, label=label, color=color)
ax.set_xlabel('Seed')
ax.set_ylabel('Total Reward')
ax.set_title('Total Reward across Seeds')
ax.legend()
ax.grid(True)
plt.show()

In [None]:

def _collect_pipeline_states(eval_env, inference_fn, trajectory_length, rng):
    """Run one un-batched rollout and return the pipeline state after reset and after each step."""
    reset_fn = jax.jit(eval_env.reset)
    step_fn = jax.jit(eval_env.step)
    state = reset_fn(rng)
    pipeline_states = [state.pipeline_state]
    for _ in range(trajectory_length):
        state = step_fn(state, inference_fn(state.obs))
        pipeline_states.append(state.pipeline_state)
    return pipeline_states

# `rollout_realistic` was never defined, so this cell raised NameError.
# `rollout_policy` returns observations, but `html.render` (brax.io.html, which
# shadows the stdlib `html` import) needs a list of pipeline states, so collect
# them explicitly. The rollout uses evaluation key 0, i.e. "seed 0" in the plot.
# TODO confirm: PODS on the realistic env was chosen to match `rewards_realistic`.
# Swap in `inference_fn_hds` to render HDS instead.
# NOTE(review): assumes RealisticPendulum states expose `.pipeline_state` like Pendulum.
rollout_realistic = _collect_pipeline_states(realistic_env, inference_fn_pods, 100, subkeys[0])
# NOTE(review): renders with the training env's system, as before. Confirm it
# matches the realistic env's model.
rendered_html = html.render(env.sys.replace(dt=env.dt), rollout_realistic)
with open("data.html", "w", encoding="utf-8") as file:
    file.write(rendered_html)