# Test suite for env_template_sa, applied to the `Rbc` environment (`marketsai.rbc.env_rbc`)

In [1]:
from marketsai.rbc.env_rbc import Rbc
import numpy as np
import seaborn as sn
import matplotlib.pyplot as plt
import time

In [2]:
# environment config
# Configuration handed to every Rbc(...) constructed below; later cells make
# shallow copies of it and override individual top-level flags.
# With rew_mean=0 and rew_std=1 the reward scaling is presumably the identity
# (confirm against Rbc). Alternative, currently disabled values:
#   rew_mean=0.9200565795467147, rew_std=0.3003009455512563
env_config = dict(
    horizon=200,
    eval_mode=True,
    analysis_mode=False,
    simul_mode=False,
    max_action=0.6,
    rew_mean=0,
    rew_std=1,
    parameters=dict(
        alpha=0.36,
        delta=0.025,
        beta=0.99,
    ),
)


def process_rewards(r, BETA):
    """Return the discounted sum of a reward sequence seen from period 0.

    Computes ``r[0] + BETA * r[1] + BETA**2 * r[2] + ...`` by folding the
    sequence from the back, one multiply-add per element.

    Args:
        r: Iterable of per-period rewards (e.g. a list or 1-D numpy array).
        BETA: Discount factor applied once per period.

    Returns:
        The discounted return as a float. An empty sequence yields 0.0.
    """
    # Float accumulator: an integer-typed buffer (as np.zeros_like on an int
    # list would give) silently truncates fractional discounted values.
    discounted_return = 0.0
    # G_t = r_t + BETA * G_{t+1}, evaluated from the last period backwards.
    for reward in reversed(list(r)):
        discounted_return = discounted_return * BETA + reward
    return discounted_return

In [None]:
# validate spaces
# Sanity-check that the action/observation spaces declared by Rbc agree with
# what reset() and step() actually return. Results are printed, not asserted.
env = Rbc(env_config=env_config)

# Draw each sample once so the printed type and value describe the same draw.
action_sample = env.action_space.sample()
print(
    "action space type:",
    type(action_sample),
    "action space sample:",
    action_sample,
)
obs_sample = env.observation_space.sample()
print(
    "obs space type:",
    type(obs_sample),
    "obs space sample:",
    obs_sample,
)

obs_init = env.reset()
obs_init_ok = env.observation_space.contains(obs_init)
print(
    "obs_init contained in obs_space?",
    obs_init_ok,
)
if not obs_init_ok:
    print(obs_init)
print(
    "random number in [-1,1] contained in action_space?",
    env.action_space.contains(np.array([np.random.uniform(-1, 1)])),
)

obs, rew, done, info = env.step(env.action_space.sample())
obs_step_ok = env.observation_space.contains(obs)
print(
    "obs after step contained in obs space?",
    obs_step_ok,
)
# Show the offending observation, as is done for obs_init above.
if not obs_step_ok:
    print(obs)

In [None]:
# Time one construct / reset / step cycle of Rbc.
# time.perf_counter() is the monotonic, highest-resolution clock intended for
# measuring short intervals. time.time() may be too coarse, and a step that
# measures as 0 s would make the passthrough division raise ZeroDivisionError.
data_timing = {
    "time_init": [],  # milliseconds
    "time_reset": [],  # milliseconds
    "time_step": [],  # milliseconds
    "max_passthrough": [],  # steps per second implied by one step's duration
}

time_preinit = time.perf_counter()
env = Rbc(env_config=env_config)
time_postinit = time.perf_counter()
env.reset()
time_postreset = time.perf_counter()
obs, rew, done, info = env.step(np.array([np.random.uniform(-1, 1)]))
time_poststep = time.perf_counter()

step_seconds = time_poststep - time_postreset
data_timing["time_init"].append((time_postinit - time_preinit) * 1000)
data_timing["time_reset"].append((time_postreset - time_postinit) * 1000)
data_timing["time_step"].append(step_seconds * 1000)
data_timing["max_passthrough"].append(
    1 / step_seconds if step_seconds > 0 else float("inf")
)
print(data_timing)

In [4]:
# simulate
# Run Rbc.random_sample for SIMUL_PERIODS periods and print the summary stats
# it returns for capital, per-period rewards and discounted rewards. Each
# stats list is printed after a label naming its entries in order.
SIMUL_PERIODS = 1000000
env = Rbc(env_config=env_config)
print("steady_state", env.k_ss)
cap_stats, rew_stats, rew_disc_stats = env.random_sample(SIMUL_PERIODS)
print(
    "[cap_max, cap_min, cap_mean, cap_std]:",
    cap_stats,
    "\n" + "[rew_max, rew_min, rew_mean, rew_std]:",
    rew_stats,
    "\n" + "[rew_disc_max, rew_disc_min, rew_disc_mean, rew_disc_std]:",
    rew_disc_stats,
)

steady_state 37.989253538152255
[cap_max, cap_min, cap_mean, cap_std]: [67.1015911841568, 17.10507043429091, 38.886988509332156, 8.63749487933472] 
[rew_max, rew_min, rew_mean, rew_std:] [1.5681634546338266, 0.022497834130237468, 0.9138102408757718, 0.2939973479751011] 
[rew_disc_max, rew_disc_min, rew_disc_mean, rew_disc_std:] [76.31057981929418, 72.65286798907721, 74.78982161488845, 0.4214407435202891]


In [None]:
# run analysis mode
# Roll Rbc forward with randomly sampled actions and analysis_mode switched on.
# Collect capital, rewards and env.obs_global[1] (recorded as the shock), then
# print summary stats and plot the shock path.
ANALYSIS_PERIODS = 1000
ANALYSIS_RESET_EVERY = 1000  # reset interval in periods

# Shallow copy: the nested "parameters" dict is shared with env_config; only
# top-level keys are overridden here.
env_config_analysis = env_config.copy()
env_config_analysis["analysis_mode"] = True
env = Rbc(env_config=env_config_analysis)
k_list = []
rew_list = []
shock_list = []

obs = env.reset()
for t in range(ANALYSIS_PERIODS):
    # Skip t == 0: the reset above already started the first episode, so a
    # second reset here would be redundant.
    if t > 0 and t % ANALYSIS_RESET_EVERY == 0:
        obs = env.reset()
    # NOTE(review): `done` is ignored, so stepping continues past
    # env_config["horizon"] unless Rbc handles that internally; confirm.
    obs, rew, done, info = env.step(env.action_space.sample())
    shock_list.append(env.obs_global[1])
    k_list.append(info["capital"])
    rew_list.append(info["rewards"])

# Discount with the configured beta instead of a hard-coded copy of it.
disc_rew = process_rewards(rew_list, env_config_analysis["parameters"]["beta"])
print(
    "Discounted Rewards",
    disc_rew,
    "\n" + "cap_stats:",
    [
        np.max(k_list),
        np.min(k_list),
        np.mean(k_list),
        np.std(k_list),
    ],
    "\n" + "reward_stats:",
    [np.max(rew_list), np.min(rew_list), np.mean(rew_list), np.std(rew_list)],
)

fig, ax = plt.subplots()
ax.plot(shock_list)
ax.legend(["shock"])
ax.set(xlabel="period", ylabel="shock", title="Shock path (analysis mode)")
plt.show()

In [None]:
# run evaluation mode
# Roll Rbc forward with randomly sampled actions, eval_mode and simul_mode
# switched on. Collect capital, rewards and env.obs_global[1] (recorded as the
# shock), then print summary stats and plot the shock path.
EVAL_PERIODS = 200  # same value as env_config["horizon"]
EVAL_RESET_EVERY = 200  # reset interval in periods

# Shallow copy: the nested "parameters" dict is shared with env_config; only
# top-level keys are overridden here.
env_config_eval = env_config.copy()
env_config_eval["eval_mode"] = True
env_config_eval["simul_mode"] = True
env = Rbc(env_config=env_config_eval)
k_list = []
rew_list = []
shock_list = []

obs = env.reset()
for t in range(EVAL_PERIODS):
    # Skip t == 0: the reset above already started the first episode, so a
    # second reset here would be redundant.
    if t > 0 and t % EVAL_RESET_EVERY == 0:
        obs = env.reset()
    obs, rew, done, info = env.step(env.action_space.sample())
    k_list.append(info["capital"])
    shock_list.append(env.obs_global[1])
    rew_list.append(info["rewards"])

print(
    "cap_stats:",
    [
        np.max(k_list),
        np.min(k_list),
        np.mean(k_list),
        np.std(k_list),
    ],
    "reward_stats:",
    [np.max(rew_list), np.min(rew_list), np.mean(rew_list), np.std(rew_list)],
)

fig, ax = plt.subplots()
ax.plot(shock_list)
ax.legend(["shock"])
ax.set(xlabel="period", ylabel="shock", title="Shock path (evaluation mode)")
plt.show()