In [1]:
import sys
import numpy as np
import gym
import ray
from ray.rllib.agents import ppo, a3c, cql, ddpg, dqn

# Only surface gym log messages at ERROR severity (numeric level 40) or above.
gym.logger.set_level(gym.logger.ERROR)

# Put the parent directory on the import path so the `source.*` imports below
# resolve. The guard keeps re-runs of this cell from appending duplicates.
if ".." not in sys.path:
    sys.path.append("..")

from source.envs.env import WhitedBasicModel
from source.solvers.ray_solver import RaySolver

%load_ext autoreload
%autoreload 1

A3C_Trainer = a3c.A3CTrainer
PPO_Trainer = ppo.PPOTrainer
DQNTrainer = dqn.DQNTrainer

In [2]:
from source.utils.useful_class import ParameterGrid

# Sanity check: iterate a ParameterGrid built from two value lists and print
# each parameter combination it yields.
grid = dict(
    delta=[0.1, 0.2, 0.3],
    gamma=[1, 10],
)
pg = ParameterGrid(grid)
for combo in pg:
    print(combo)
    

{'delta': 0.1, 'gamma': 1}
{'delta': 0.1, 'gamma': 10}
{'delta': 0.2, 'gamma': 1}
{'delta': 0.2, 'gamma': 10}
{'delta': 0.3, 'gamma': 1}
{'delta': 0.3, 'gamma': 10}


In [3]:
# Train an A3C agent on WhitedBasicModel inside a fresh local Ray runtime.
ray.shutdown()  # drop any runtime left over from a previous run of this cell
ray.init()
try:
    env = WhitedBasicModel(env_config={
        # Each structural parameter is given as a two-element list; presumably
        # a [low, high] range the env draws from -- confirm in WhitedBasicModel.
        "structural_params": {
            "gamma": [0.9, 0.96],
            "delta": [0.1, 0.3],
            "theta": [0.5, 0.8],
            "rho": [0.3, 0.8],
            "sigma": [0., 0.15],
        },
        # psi_func is quadratic in i and scaled by 1/(2*k); undefined at k == 0.
        "env_params": {"psi_func": lambda i, k: 0.01*i**2/(2*k)},
        "is_mutable": True,
    })
    solver = RaySolver(
        env=env,
        trainer=A3C_Trainer,
        solver_params={
            "verbose": True,
            "episodes": 10,
            "trainer_config": {
                "num_workers": 8,
                # Pass the env's current gamma to the trainer config,
                # falling back to 0.99 when the env does not define one.
                "gamma": env.current_structural_params.get("gamma", 0.99),
            },
        },
    )
    solver.train()
finally:
    # Always release the Ray runtime, even when env construction or training
    # raises; previously an exception skipped this call and left Ray running.
    ray.shutdown()

2022-01-25 19:47:20,508	INFO services.py:1340 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m
2022-01-25 19:47:21,611	INFO trainer.py:745 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


agent_timesteps_total: 0
custom_metrics: {}
date: 2022-01-25_19-47-23
done: false
episode_len_mean: .nan
episode_media: {}
episode_reward_max: .nan
episode_reward_mean: .nan
episode_reward_min: .nan
episodes_this_iter: 0
episodes_total: 0
experiment_id: e81217aa91364ff488289d6a917b607b
hostname: mw-14.local
info:
  learner:
    default_policy:
      batch_count: 10
      learner_stats:
        allreduce_latency: 0.0
        cur_lr: 0.0001
        entropy_coeff: 0.01
        grad_gnorm: 168.23855590820312
        policy_entropy: 29.95711326599121
        policy_loss: -94.7829818725586
        vf_loss: 54.918212890625
  num_steps_sampled: 10
  num_steps_trained: 10
iterations_since_restore: 1
node_ip: 127.0.0.1
num_healthy_workers: 8
off_policy_estimator: {}
perf:
  cpu_util_percent: 49.5
  ram_util_percent: 89.4
pid: 18812
policy_reward_max: {}
policy_reward_mean: {}
policy_reward_min: {}
sampler_perf: {}
time_since_restore: 0.09767699241638184
time_this_iter_s: 0.09767699241638184
time

In [4]:
# Optional: uncomment to write a checkpoint of the trained policy.
# solver.trainer.save()

In [5]:
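# Optional: uncomment to restore a previously saved checkpoint instead of retraining.
# NOTE: the path below is specific to the original author's machine; replace it with your own checkpoint path.
# solver.trainer.restore('/Users/mingweima/ray_results/A3C_my-env_2022-01-25_18-51-04pgozg8qt/checkpoint_000010/checkpoint-10')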

In [6]:
# Ask the solver for a sample at one fixed set of structural parameters.
sample_params = {
    "gamma": 0.9,
    "delta": 0.15,
    "theta": 0.5,
    "rho": 0.5,
    "sigma": 0.15,
}
sample_obs = solver.sample(param_dict=sample_params)
data = sample_obs  # second name for the same object, as in the original cell

In [49]:
# FIXME(review): this cell is an unfinished fragment and is not valid Python.
# The assignment below has no right-hand side, and the parenthesised lines
# look like part of a pasted call/signature (keywords `data`, `solver`,
# `estimator_params`) whose callee is missing -- presumably an estimator built
# from the sampled `data` and the `solver`; confirm, then complete or delete.
# The outputs recorded under this cell cannot have been produced by this code.
estimator_params = 

(data=data,  # (nsamples, N, T) or (N, T); N: obs dim, T: eps length
                 solver=solver,
                 estimator_params: ):

(14,)


33575694.91361878

In [None]:
"""Roll out the trained agent for several episodes and average obs[0] across episodes at each time step."""
# Run the trained policy for a few episodes, record obs[0] after every step
# (presumably the capital stock, going by the variable names -- confirm), and
# average that path across episodes.
n_eval_episodes = 2
cp = []
for eps in range(n_eval_episodes):
    episode_reward = 0  # running return of this episode; not aggregated below
    done = False
    obs = env.reset()
    caps = []  # obs[0] after each step of the current episode
    # Step until the env signals the end of the episode.
    while not done:
        action = solver.trainer.compute_single_action(obs, clip_action=True)
        obs, reward, done, info = env.step(action, resample_param=False)
        episode_reward += reward
        caps.append(obs[0])
    cp.append(caps)
# Stack to (n_eval_episodes, T) and average over the episode axis.
# Assumes every episode has the same length T.
# np.squeeze was dropped on purpose: it also removed the episode axis when
# only one episode was run (and the step axis for one-step episodes), which
# made mean(axis=0) return a single scalar instead of a per-step path.
cp = np.array(cp).mean(axis=0)

In [None]:
import matplotlib.pyplot as plt
# Draw the episode-averaged path on the current axes and render the figure.
ax = plt.gca()
ax.plot(cp)
plt.show()