In [1]:
import os
import pandas as pd
import numpy as np
import ray

from env import QdTreeEnv
from qdtree import Workload
from policy import QdTreePolicy

from ray.rllib.env.env_context import EnvContext
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray import air, tune

In [2]:
# Toy workload over a single integer column: nine range predicates
# `col <= 600`, `col <= 650`, ..., `col <= 1000`, keyed by their index.
workload = Workload(dict(
    schema=dict(col="int"),
    queries={
        query_id: dict(type="expr", children=["col", "<=", f"{upper_bound}"])
        for query_id, upper_bound in enumerate(range(600, 1001, 50))
    },
))

# Synthetic table: one column holding the integers 0..9999.
data = pd.DataFrame(dict(col=np.arange(10000)))

# Show the queries as the workload stores them.
workload._queries

{0: col <= 600,
 1: col <= 650,
 2: col <= 700,
 3: col <= 750,
 4: col <= 800,
 5: col <= 850,
 6: col <= 900,
 7: col <= 950,
 8: col <= 1000}

In [3]:
# Start the Ray runtime for this notebook. `ignore_reinit_error=True` keeps the
# cell safe to re-run in a live kernel: without it a second `ray.init()` call
# raises because Ray has already been initialised.
ray.init(local_mode=False, ignore_reinit_error=True)

2023-05-06 20:27:30,744	INFO worker.py:1553 -- Started a local Ray instance.


0,1
Python version:,3.8.10
Ray version:,2.3.1


In [41]:
# Leaf-size limit handed to the environment via `env_config`
# (presumably the minimum number of records per leaf -- confirm in QdTreeEnv).
MIN_LEAF_SIZE = 200

# Everything the environment receives through its config: the query workload,
# the data frame and the leaf-size limit.
env_config = {
    "workload": workload,
    "data": data,
    "min_leaf_size": MIN_LEAF_SIZE,
}

# PPO configuration shared by the Tune run and by the evaluation agent.
config = (
    PPOConfig()
    .environment(
        QdTreeEnv,
        env_config=env_config,
    )
    .framework("torch")
    # One rollout worker; sample batches are cut at episode boundaries only.
    .rollouts(num_rollout_workers=1, batch_mode="complete_episodes")
    .training(
        model={
            "fcnet_hiddens": [128],
            "fcnet_activation": "relu",
            # "vf_share_layers": True,
        }
    )
    # Use the project's QdTreePolicy as the (single) default policy.
    .multi_agent(
        policies={DEFAULT_POLICY_ID: PolicySpec(policy_class=QdTreePolicy)}
    )
    # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
    .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
)

# NOTE: no Algorithm is built in this cell. Tune constructs its own trainable
# from `config.to_dict()`, and the evaluation cell builds `agent` itself;
# building an extra, never-stopped algorithm here would only hold a rollout
# worker (and its resources) for nothing.

# Stopping criteria for training: the run ends as soon as ANY of these
# thresholds is reached.
stop = {
    "training_iteration": 10,
    "timesteps_total": 10000,
    "episode_reward_mean": 0.95,
}





In [None]:
# from ray.tune.logger import pretty_print

# config.lr = 1e-3
# algo = config.build()
# # run manual training loop and print results after each iteration
# for _ in range(stop["training_iteration"]):
#     result = algo.train()
#     print(pretty_print(result))
#     # stop training if the target timesteps or reward are reached
#     if (
#         result["timesteps_total"] >= stop["timesteps_total"]
#         or result["episode_reward_mean"] >= stop["episode_reward_mean"]
#     ):
#         break
# algo.stop()


In [42]:
# Run settings: results go under ./results, training ends on the `stop`
# criteria, and a checkpoint is written when the trial finishes.
run_config = air.RunConfig(
    stop=stop,
    local_dir="./results",
    checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True),
)

# Train PPO through Tune with the notebook's config as the parameter space.
tuner = tune.Tuner("PPO", run_config=run_config, param_space=config.to_dict())
result = tuner.fit()

0,1
Current time:,2023-05-06 21:10:18
Running for:,00:00:50.37
Memory:,7.8/31.3 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
PPO_QdTreeEnv_d08b8_00000,TERMINATED,192.168.75.74:15072,3,43.4941,12004,0.730018,0.913233,0.104433,5.27704


[2m[36m(PPO pid=15072)[0m 2023-05-06 21:09:31,596	INFO algorithm.py:506 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


Trial name,agent_timesteps_total,connector_metrics,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,experiment_id,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_trained,num_env_steps_trained_this_iter,num_faulty_episodes,num_healthy_workers,num_in_flight_async_reqs,num_remote_worker_restarts,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
PPO_QdTreeEnv_d08b8_00000,12004,"{'ObsPreprocessorConnector_ms': 0.007602564577691474, 'StateBufferConnector_ms': 0.010262472962955686, 'ViewRequirementAgentConnector_ms': 0.1285584746690413}","{'num_env_steps_sampled': 12004, 'num_env_steps_trained': 12004, 'num_agent_steps_sampled': 12004, 'num_agent_steps_trained': 12004}",{},2023-05-06_21-10-18,True,5.27704,{},0.913233,0.730018,0.104433,758,2398,970b4720716f432f87d944fddad4da0a,HOMELAB,"{'learner': {'default_policy': {'learner_stats': {'allreduce_latency': 0.0, 'grad_gnorm': 1.7551520158526719, 'cur_kl_coeff': 0.4500000000000001, 'cur_lr': 5.0000000000000016e-05, 'total_loss': 0.4144470496043082, 'policy_loss': -0.0395048292433863, 'vf_loss': 0.4452335016060901, 'vf_explained_var': 0.38137393516878926, 'kl': 0.01937417322210922, 'entropy': 2.0354708151150773, 'entropy_coeff': 0.0}, 'model': {}, 'custom_metrics': {}, 'num_agent_steps_trained': 128.0, 'num_grad_updates_lifetime': 2325.5, 'diff_num_grad_updates_vs_sampler_policy': 464.5}}, 'num_env_steps_sampled': 12004, 'num_env_steps_trained': 12004, 'num_agent_steps_sampled': 12004, 'num_agent_steps_trained': 12004}",3,192.168.75.74,12004,12004,12004,4000,12004,4000,0,1,0,0,4000,"{'cpu_util_percent': 5.1000000000000005, 'ram_util_percent': 24.899999999999995}",15072,{},{},{},"{'mean_raw_obs_processing_ms': 0.5699220273654593, 'mean_inference_ms': 1.0940420285010026, 'mean_action_processing_ms': 0.12318812922009424, 'mean_env_wait_ms': 0.5210910623940466, 'mean_env_render_ms': 0.0}","{'episode_reward_max': 0.9132333333333333, 'episode_reward_min': 0.10443333333333334, 'episode_reward_mean': 0.7300181031955437, 'episode_len_mean': 5.277044854881266, 'episode_media': {}, 'episodes_this_iter': 758, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [0.10443333333333334, 0.9076777777777778, 0.9110111111111111, 0.10443333333333334, 0.6127111111111111, 0.8043555555555556, 
0.8126888888888889, 0.8126888888888889, 0.8132444444444444, 0.9110111111111111, 0.30996666666666667, 0.30996666666666667, 0.9076777777777778, 0.8126888888888889, 0.5110555555555556, 0.8043555555555556, 0.7110333333333333, 0.8110222222222222, 0.8126888888888889, 0.9076777777777778, 0.7143666666666667, 0.9110111111111111, 0.8132444444444444, 0.9110111111111111, 0.6099333333333333, 0.8082444444444444, 0.30996666666666667, 0.20775555555555555, 0.6143777777777778, 0.7077, 0.10443333333333334, 0.8999, 0.5132777777777778, 0.20775555555555555, 0.9110111111111111, 0.7110333333333333, 0.9110111111111111, 0.8043555555555556, 0.20775555555555555, 0.9099, 0.8110222222222222, 0.8043555555555556, 0.8126888888888889, 0.8126888888888889, 0.6099333333333333, 0.7132555555555555, 0.8126888888888889, 0.7132555555555555, 0.9099, 0.8043555555555556, 0.9076777777777778, 0.8082444444444444, 0.8999, 0.7110333333333333, 0.8999, 0.8110222222222222, 0.30996666666666667, 0.6143777777777778, 0.7143666666666667, 0.9076777777777778, 0.7143666666666667, 0.7132555555555555, 0.8999, 0.9076777777777778, 0.8043555555555556, 0.9099, 0.30996666666666667, 0.9110111111111111, 0.8043555555555556, 0.9076777777777778, 0.10443333333333334, 0.9099, 0.8082444444444444, 0.10443333333333334, 0.9099, 0.8110222222222222, 0.8999, 0.7077, 0.7110333333333333, 0.7077, 0.5110555555555556, 0.8999, 0.8132444444444444, 0.10443333333333334, 0.9132333333333333, 0.8126888888888889, 0.9076777777777778, 0.7132555555555555, 0.9110111111111111, 0.8110222222222222, 0.9043444444444444, 0.8110222222222222, 0.9099, 0.9099, 0.30996666666666667, 0.8110222222222222, 0.8110222222222222, 0.8082444444444444, 0.9043444444444444, 0.8043555555555556, 0.20775555555555555, 0.30996666666666667, 0.8043555555555556, 0.6099333333333333, 0.9110111111111111, 0.6143777777777778, 0.7077, 0.8043555555555556, 0.20775555555555555, 0.9076777777777778, 0.6099333333333333, 0.9110111111111111, 0.8999, 0.8043555555555556, 0.9076777777777778, 0.7077, 
0.9076777777777778, 0.6127111111111111, 0.9076777777777778, 0.8110222222222222, 0.8043555555555556, 0.6099333333333333, 0.9110111111111111, 0.9076777777777778, 0.8043555555555556, 0.9132333333333333, 0.9099, 0.8999, 0.8110222222222222, 0.7110333333333333, 0.9110111111111111, 0.9110111111111111, 0.4110666666666667, 0.7110333333333333, 0.6143777777777778, 0.9076777777777778, 0.9110111111111111, 0.8999, 0.8126888888888889, 0.9099, 0.20775555555555555, 0.9132333333333333, 0.7132555555555555, 0.7143666666666667, 0.9076777777777778, 0.9043444444444444, 0.8043555555555556, 0.7143666666666667, 0.7110333333333333, 0.8126888888888889, 0.7077, 0.7143666666666667, 0.9099, 0.7077, 0.6099333333333333, 0.8126888888888889, 0.8110222222222222, 0.8110222222222222, 0.8110222222222222, 0.7143666666666667, 0.30996666666666667, 0.8082444444444444, 0.6099333333333333, 0.8126888888888889, 0.8132444444444444, 0.9076777777777778, 0.7077, 0.9043444444444444, 0.8126888888888889, 0.4110666666666667, 0.10443333333333334, 0.9076777777777778, 0.7110333333333333, 0.9099, 0.5110555555555556, 0.9043444444444444, 0.7143666666666667, 0.9076777777777778, 0.4110666666666667, 0.9110111111111111, 0.20775555555555555, 0.9043444444444444, 0.8999, 0.10443333333333334, 0.8126888888888889, 0.8043555555555556, 0.8126888888888889, 0.8043555555555556, 0.7132555555555555, 0.8126888888888889, 0.8110222222222222, 0.8126888888888889, 0.7132555555555555, 0.8132444444444444, 0.4110666666666667, 0.9110111111111111, 0.7143666666666667, 0.7132555555555555, 0.6143777777777778, 0.5110555555555556, 0.9099, 0.8999, 0.7132555555555555, 0.8126888888888889, 0.30996666666666667, 0.9099, 0.6127111111111111, 0.10443333333333334, 0.20775555555555555, 0.7110333333333333, 0.4110666666666667, 0.7110333333333333, 0.8132444444444444, 0.6127111111111111, 0.9043444444444444, 0.9132333333333333, 0.8999, 0.8999, 0.8110222222222222, 0.7077, 0.9099, 0.8082444444444444, 0.4110666666666667, 0.8999, 0.7077, 0.8132444444444444, 0.9099, 
0.8110222222222222, 0.7077, 0.9110111111111111, 0.8043555555555556, 0.9076777777777778, 0.5110555555555556, 0.8110222222222222, 0.20775555555555555, 0.7077, 0.9099, 0.7110333333333333, 0.8110222222222222, 0.5110555555555556, 0.9110111111111111, 0.9110111111111111, 0.9076777777777778, 0.7077, 0.7143666666666667, 0.8999, 0.7077, 0.8043555555555556, 0.5110555555555556, 0.8110222222222222, 0.7077, 0.7077, 0.8110222222222222, 0.8110222222222222, 0.8043555555555556, 0.8043555555555556, 0.9076777777777778, 0.9043444444444444, 0.9043444444444444, 0.8082444444444444, 0.7110333333333333, 0.8043555555555556, 0.9099, 0.30996666666666667, 0.30996666666666667, 0.20775555555555555, 0.8043555555555556, 0.7110333333333333, 0.9076777777777778, 0.9099, 0.8999, 0.20775555555555555, 0.8043555555555556, 0.8082444444444444, 0.9110111111111111, 0.9043444444444444, 0.7077, 0.20775555555555555, 0.8999, 0.8110222222222222, 0.7143666666666667, 0.9099, 0.10443333333333334, 0.6127111111111111, 0.8110222222222222, 0.6127111111111111, 0.9099, 0.7077, 0.9043444444444444, 0.8999, 0.8043555555555556, 0.8132444444444444, 0.7132555555555555, 0.8110222222222222, 0.9110111111111111, 0.9099, 0.30996666666666667, 0.7143666666666667, 0.8082444444444444, 0.20775555555555555, 0.7077, 0.8999, 0.8110222222222222, 0.9076777777777778, 0.8082444444444444, 0.7110333333333333, 0.8082444444444444, 0.20775555555555555, 0.7110333333333333, 0.6099333333333333, 0.9110111111111111, 0.8082444444444444, 0.8043555555555556, 0.6143777777777778, 0.8999, 0.4110666666666667, 0.8126888888888889, 0.7132555555555555, 0.8043555555555556, 0.6143777777777778, 0.8132444444444444, 0.5110555555555556, 0.7132555555555555, 0.7110333333333333, 0.7110333333333333, 0.10443333333333334, 0.6127111111111111, 0.8126888888888889, 0.9110111111111111, 0.6143777777777778, 0.7132555555555555, 0.7143666666666667, 0.9099, 0.5110555555555556, 0.9043444444444444, 0.7110333333333333, 0.9099, 0.30996666666666667, 0.7110333333333333, 0.6099333333333333, 
0.7143666666666667, 0.9076777777777778, 0.7077, 0.9110111111111111, 0.8999, 0.8110222222222222, 0.9043444444444444, 0.9076777777777778, 0.4110666666666667, 0.6143777777777778, 0.8126888888888889, 0.8110222222222222, 0.8126888888888889, 0.9099, 0.8126888888888889, 0.8126888888888889, 0.30996666666666667, 0.6099333333333333, 0.30996666666666667, 0.7110333333333333, 0.9099, 0.9043444444444444, 0.6099333333333333, 0.9099, 0.6099333333333333, 0.9099, 0.8999, 0.9132333333333333, 0.7077, 0.8999, 0.7143666666666667, 0.9043444444444444, 0.7143666666666667, 0.8082444444444444, 0.8043555555555556, 0.8126888888888889, 0.8110222222222222, 0.5110555555555556, 0.8132444444444444, 0.8999, 0.7110333333333333, 0.8110222222222222, 0.9076777777777778, 0.9132333333333333, 0.8043555555555556, 0.8043555555555556, 0.8999, 0.8110222222222222, 0.8082444444444444, 0.5110555555555556, 0.8126888888888889, 0.6099333333333333, 0.8126888888888889, 0.7110333333333333, 0.8110222222222222, 0.9110111111111111, 0.8082444444444444, 0.7143666666666667, 0.8082444444444444, 0.6099333333333333, 0.8043555555555556, 0.5132777777777778, 0.7132555555555555, 0.9110111111111111, 0.9099, 0.8043555555555556, 0.6127111111111111, 0.9076777777777778, 0.8043555555555556, 0.4110666666666667, 0.8082444444444444, 0.6099333333333333, 0.7143666666666667, 0.9110111111111111, 0.8110222222222222, 0.7110333333333333, 0.9099, 0.9043444444444444, 0.8110222222222222, 0.7077, 0.9076777777777778, 0.6099333333333333, 0.8999, 0.7110333333333333, 0.8132444444444444, 0.9076777777777778, 0.8110222222222222, 0.9043444444444444, 0.9043444444444444, 0.8126888888888889, 0.9043444444444444, 0.9099, 0.8999, 0.6127111111111111, 0.30996666666666667, 0.8132444444444444, 0.20775555555555555, 0.30996666666666667, 0.6127111111111111, 0.8043555555555556, 0.8126888888888889, 0.6099333333333333, 0.30996666666666667, 0.9132333333333333, 0.8082444444444444, 0.8082444444444444, 0.5110555555555556, 0.9043444444444444, 0.9110111111111111, 0.9099, 
0.7143666666666667, 0.8126888888888889, 0.9099, 0.5132777777777778, 0.8999, 0.6143777777777778, 0.8082444444444444, 0.8126888888888889, 0.8126888888888889, 0.9110111111111111, 0.20775555555555555, 0.9099, 0.8110222222222222, 0.9076777777777778, 0.8043555555555556, 0.8082444444444444, 0.8110222222222222, 0.20775555555555555, 0.9076777777777778, 0.8132444444444444, 0.8043555555555556, 0.8999, 0.20775555555555555, 0.8082444444444444, 0.8043555555555556, 0.4110666666666667, 0.9076777777777778, 0.8126888888888889, 0.8999, 0.9043444444444444, 0.8110222222222222, 0.8043555555555556, 0.9043444444444444, 0.8132444444444444, 0.4110666666666667, 0.6127111111111111, 0.4110666666666667, 0.9099, 0.7132555555555555, 0.7110333333333333, 0.9043444444444444, 0.8043555555555556, 0.6127111111111111, 0.9110111111111111, 0.6143777777777778, 0.8082444444444444, 0.20775555555555555, 0.8126888888888889, 0.8082444444444444, 0.8999, 0.20775555555555555, 0.20775555555555555, 0.6099333333333333, 0.9110111111111111, 0.8999, 0.7132555555555555, 0.4110666666666667, 0.8082444444444444, 0.8132444444444444, 0.8110222222222222, 0.30996666666666667, 0.30996666666666667, 0.9110111111111111, 0.6143777777777778, 0.8043555555555556, 0.7132555555555555, 0.9076777777777778, 0.8043555555555556, 0.9099, 0.8043555555555556, 0.8126888888888889, 0.8110222222222222, 0.30996666666666667, 0.8082444444444444, 0.9110111111111111, 0.7132555555555555, 0.8999, 0.7110333333333333, 0.6143777777777778, 0.7110333333333333, 0.7132555555555555, 0.7077, 0.9043444444444444, 0.8082444444444444, 0.7143666666666667, 0.8043555555555556, 0.8999, 0.7077, 0.20775555555555555, 0.6127111111111111, 0.8082444444444444, 0.9099, 0.8110222222222222, 0.8082444444444444, 0.7132555555555555, 0.7110333333333333, 0.5132777777777778, 0.6099333333333333, 0.8043555555555556, 0.10443333333333334, 0.7143666666666667, 0.8999, 0.6099333333333333, 0.9110111111111111, 0.8082444444444444, 0.8126888888888889, 0.7077, 0.30996666666666667, 0.8082444444444444, 
0.7077, 0.30996666666666667, 0.8043555555555556, 0.9110111111111111, 0.7077, 0.8082444444444444, 0.6143777777777778, 0.8082444444444444, 0.9099, 0.9099, 0.10443333333333334, 0.7077, 0.8043555555555556, 0.5110555555555556, 0.7143666666666667, 0.7077, 0.9043444444444444, 0.7132555555555555, 0.9099, 0.9076777777777778, 0.5132777777777778, 0.9043444444444444, 0.9110111111111111, 0.7077, 0.9076777777777778, 0.6099333333333333, 0.4110666666666667, 0.7110333333333333, 0.9043444444444444, 0.4110666666666667, 0.6099333333333333, 0.8110222222222222, 0.9110111111111111, 0.8110222222222222, 0.8126888888888889, 0.7077, 0.7132555555555555, 0.6099333333333333, 0.9099, 0.8126888888888889, 0.9043444444444444, 0.8110222222222222, 0.8082444444444444, 0.30996666666666667, 0.6143777777777778, 0.9099, 0.9110111111111111, 0.30996666666666667, 0.5110555555555556, 0.9076777777777778, 0.8999, 0.8082444444444444, 0.9076777777777778, 0.8126888888888889, 0.9043444444444444, 0.8043555555555556, 0.7132555555555555, 0.7077, 0.30996666666666667, 0.8999, 0.9043444444444444, 0.20775555555555555, 0.9099, 0.8110222222222222, 0.7077, 0.9043444444444444, 0.8126888888888889, 0.8110222222222222, 0.20775555555555555, 0.9043444444444444, 0.30996666666666667, 0.8082444444444444, 0.9099, 0.9043444444444444, 0.8126888888888889, 0.8043555555555556, 0.7110333333333333, 0.9076777777777778, 0.8132444444444444, 0.9110111111111111, 0.9110111111111111, 0.9110111111111111, 0.7110333333333333, 0.8043555555555556, 0.8999, 0.4110666666666667, 0.8126888888888889, 0.6127111111111111, 0.8043555555555556, 0.9110111111111111, 0.7143666666666667, 0.8043555555555556, 0.9043444444444444, 0.9043444444444444, 0.9099, 0.7132555555555555, 0.7077, 0.9110111111111111, 0.7132555555555555, 0.8043555555555556, 0.7132555555555555, 0.6127111111111111, 0.5132777777777778, 0.9076777777777778, 0.7077, 0.8110222222222222, 0.9099, 0.7077, 0.7110333333333333, 0.30996666666666667, 0.8043555555555556, 0.9076777777777778, 0.9076777777777778, 
0.30996666666666667, 0.8043555555555556, 0.8999, 0.8043555555555556, 0.8082444444444444, 0.7132555555555555, 0.20775555555555555, 0.9076777777777778, 0.5110555555555556, 0.9043444444444444, 0.10443333333333334, 0.8126888888888889, 0.9110111111111111, 0.9099, 0.9110111111111111, 0.9043444444444444, 0.8110222222222222, 0.8043555555555556, 0.10443333333333334, 0.5110555555555556, 0.4110666666666667, 0.8082444444444444, 0.6099333333333333, 0.9110111111111111, 0.8999, 0.4110666666666667, 0.8126888888888889, 0.8999, 0.5110555555555556, 0.6143777777777778, 0.8132444444444444, 0.8082444444444444, 0.8082444444444444, 0.9043444444444444, 0.5110555555555556, 0.7132555555555555, 0.8132444444444444, 0.9043444444444444, 0.6099333333333333, 0.20775555555555555, 0.8110222222222222, 0.8043555555555556, 0.6143777777777778, 0.8043555555555556, 0.8110222222222222, 0.9132333333333333, 0.9099, 0.8110222222222222, 0.7110333333333333, 0.20775555555555555, 0.30996666666666667, 0.7132555555555555, 0.8999, 0.8082444444444444, 0.6143777777777778, 0.8110222222222222, 0.6099333333333333, 0.8043555555555556, 0.7132555555555555, 0.8043555555555556, 0.9076777777777778, 0.9043444444444444, 0.8110222222222222, 0.5110555555555556, 0.8999, 0.6127111111111111, 0.6099333333333333, 0.20775555555555555, 0.20775555555555555, 0.8043555555555556, 0.8043555555555556, 0.8082444444444444, 0.9076777777777778, 0.9099, 0.8999, 0.8043555555555556, 0.7143666666666667, 0.9076777777777778, 0.4110666666666667, 0.9099, 0.9043444444444444, 0.7143666666666667, 0.9110111111111111, 0.10443333333333334, 0.8082444444444444, 0.8082444444444444], 'episode_lengths': [4, 6, 6, 4, 6, 4, 6, 6, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 4, 6, 4, 4, 6, 4, 4, 4, 6, 4, 6, 6, 6, 4, 4, 6, 6, 4, 6, 6, 4, 6, 6, 6, 6, 4, 6, 6, 4, 6, 4, 6, 4, 6, 6, 6, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 4, 6, 6, 4, 6, 6, 4, 4, 6, 4, 4, 4, 6, 4, 8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 6, 6, 6, 6, 4, 4, 4, 4, 4, 6, 6, 4, 4, 4, 6, 4, 6, 4, 4, 6, 4, 6, 6, 6, 6, 4, 4, 6, 
6, 4, 8, 6, 4, 6, 6, 6, 6, 4, 6, 6, 6, 6, 4, 6, 6, 4, 8, 6, 6, 6, 6, 4, 6, 6, 6, 4, 6, 6, 4, 4, 6, 6, 6, 6, 6, 4, 6, 4, 6, 6, 6, 4, 6, 6, 4, 4, 6, 6, 6, 4, 6, 6, 6, 4, 6, 4, 6, 4, 4, 6, 4, 6, 4, 6, 6, 6, 6, 6, 6, 4, 6, 6, 6, 6, 4, 6, 4, 6, 6, 4, 6, 6, 4, 4, 6, 4, 6, 6, 6, 6, 8, 4, 4, 6, 4, 6, 6, 4, 4, 4, 6, 6, 6, 4, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 4, 6, 6, 6, 4, 6, 4, 4, 4, 4, 6, 4, 4, 6, 6, 4, 4, 6, 6, 6, 6, 6, 4, 6, 4, 4, 4, 4, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 6, 6, 6, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 4, 6, 6, 4, 4, 4, 6, 6, 6, 6, 6, 4, 6, 4, 6, 6, 4, 6, 4, 4, 6, 6, 4, 6, 6, 4, 6, 6, 6, 4, 6, 6, 6, 6, 6, 6, 6, 4, 6, 6, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 6, 6, 4, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 6, 4, 6, 4, 8, 4, 4, 6, 6, 6, 6, 4, 6, 6, 4, 6, 4, 6, 6, 6, 8, 4, 4, 4, 6, 6, 4, 6, 4, 6, 6, 6, 6, 6, 6, 6, 4, 4, 6, 6, 6, 6, 4, 6, 6, 4, 4, 6, 4, 6, 6, 6, 6, 6, 6, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 4, 6, 4, 6, 4, 4, 6, 4, 6, 4, 4, 8, 6, 6, 4, 6, 6, 6, 6, 6, 6, 6, 4, 6, 6, 6, 6, 6, 4, 6, 6, 6, 4, 6, 6, 4, 6, 6, 4, 4, 4, 6, 4, 4, 6, 6, 4, 6, 6, 4, 6, 6, 4, 6, 4, 6, 6, 6, 6, 4, 6, 6, 6, 6, 4, 6, 6, 4, 4, 4, 4, 6, 4, 6, 4, 6, 6, 6, 4, 4, 6, 6, 4, 6, 6, 4, 6, 4, 6, 6, 4, 6, 6, 6, 4, 6, 6, 6, 6, 4, 6, 6, 6, 4, 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 6, 4, 4, 6, 6, 6, 4, 4, 6, 4, 4, 4, 6, 4, 6, 6, 6, 6, 6, 4, 4, 4, 4, 6, 4, 6, 6, 6, 6, 6, 6, 6, 4, 6, 4, 4, 6, 6, 4, 4, 6, 6, 6, 6, 4, 6, 4, 6, 6, 6, 6, 6, 4, 6, 6, 6, 4, 4, 6, 4, 6, 6, 6, 6, 4, 6, 4, 4, 4, 6, 4, 6, 6, 4, 6, 6, 6, 4, 6, 4, 6, 6, 6, 6, 4, 6, 6, 6, 6, 6, 6, 6, 4, 4, 4, 6, 6, 4, 6, 6, 4, 6, 6, 6, 6, 4, 6, 6, 4, 6, 6, 6, 6, 4, 6, 6, 4, 6, 4, 4, 6, 6, 4, 4, 4, 4, 6, 6, 4, 6, 4, 6, 4, 6, 6, 6, 6, 6, 6, 4, 4, 4, 4, 6, 4, 6, 4, 4, 6, 4, 4, 6, 6, 6, 6, 6, 4, 6, 6, 6, 4, 4, 6, 4, 6, 4, 6, 8, 6, 6, 6, 4, 4, 6, 4, 6, 6, 6, 4, 4, 6, 4, 6, 6, 6, 4, 4, 6, 4, 4, 4, 4, 4, 6, 6, 6, 4, 4, 6, 6, 4, 6, 6, 6, 6, 4, 6, 6]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 0.5699220273654593, 'mean_inference_ms': 
1.0940420285010026, 'mean_action_processing_ms': 0.12318812922009424, 'mean_env_wait_ms': 0.5210910623940466, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0, 'connector_metrics': {'ObsPreprocessorConnector_ms': 0.007602564577691474, 'StateBufferConnector_ms': 0.010262472962955686, 'ViewRequirementAgentConnector_ms': 0.1285584746690413}}",43.4941,14.2937,43.4941,"{'training_iteration_time_ms': 14467.87, 'load_time_ms': 0.925, 'load_throughput': 4326953.009, 'learn_time_ms': 4982.953, 'learn_throughput': 803.004, 'synch_weights_time_ms': 1.327}",1683421818,0,12004,3,d08b8_00000,3.26467


In [15]:
# Build a fresh PPO algorithm from `config`; it serves as the evaluation agent
# that the trained checkpoint is restored into in a later cell.
agent = config.build()


2023-05-06 20:48:55,575	INFO trainable.py:791 -- Restored on 192.168.75.74 from checkpoint: /home/ctring/src/qdtree/src/results/PPO/PPO_QdTreeEnv_1da3e_00000_0_2023-05-06_20-28-40/checkpoint_000001
2023-05-06 20:48:55,576	INFO trainable.py:800 -- Current state after restoring: {'_iteration': 1, '_timesteps_total': None, '_time_total': 15.46816635131836, '_episodes_total': 1000}


In [46]:
# Restore the trained weights into the evaluation agent. Metric and mode are
# given explicitly so the lookup keeps working when the result grid holds more
# than one trial (with a single trial that trial is returned either way).
best_result = result.get_best_result(metric="episode_reward_mean", mode="max")
agent.restore(best_result.checkpoint) # type: ignore

# Roll out one greedy (explore=False) episode on a fresh environment and print
# every transition, followed by the episode's total reward.
env = QdTreeEnv(EnvContext(env_config, 0))
done = False
obs, _ = env.reset()
step = 0
episode_reward = 0
while not done:
    action = agent.compute_single_action(obs, explore=False)
    # Assumes the gymnasium-style 5-tuple
    # (obs, reward, terminated, truncated, info) -- confirm in QdTreeEnv.step.
    obs, reward, terminated, truncated, info = env.step(action) # type: ignore
    # An episode is over when it terminates OR is truncated; ignoring the
    # truncation flag would leave this loop spinning forever on truncation.
    done = terminated or truncated
    episode_reward += reward
    print(step, action, obs, reward, done, info)
    step += 1

print(episode_reward)


2023-05-06 21:14:18,115	INFO trainable.py:791 -- Restored on 192.168.75.74 from checkpoint: /home/ctring/src/qdtree/src/results/PPO/PPO_QdTreeEnv_d08b8_00000_0_2023-05-06_21-09-28/checkpoint_000003
2023-05-06 21:14:18,116	INFO trainable.py:800 -- Current state after restoring: {'_iteration': 3, '_timesteps_total': None, '_time_total': 43.49412679672241, '_episodes_total': 2398}


0 8 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0] 0 False {}
1 0 [0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] 0 False {}
2 8 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 0] 0 False {}
3 0 [0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0] 0 False {}
4 6 [0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0] 0.9043444444444444 False {'rewards': array([0.90434444, 0.04440004, 1.        , 0.        , 0.11111111,
       0.        ])}
5 6 [0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1 0] 0 True {}
0.9043444444444444


In [47]:
# Show the tree held by `env` after the rollout above (uses the tree's string form).
print(env.qd_tree)

{'id': 1,
 'cut': col <= 1000,
 'size': 10000,
 'block': {'col': (-inf, inf)},
 'skipped_records': 81391,
 'left': {'id': 2,
          'cut': col <= 600,
          'size': 1001,
          'block': {'col': (-inf, 1000]},
          'skipped_records': 400,
          'left': {'id': 4,
                   'cut': None,
                   'size': 601,
                   'block': {'col': (-inf, 600]},
                   'skipped_records': 0,
                   'left': None,
                   'right': None},
          'right': {'id': 5,
                    'cut': None,
                    'size': 400,
                    'block': {'col': (600, 1000]},
                    'skipped_records': 400,
                    'left': None,
                    'right': None}},
 'right': {'id': 3,
           'cut': None,
           'size': 8999,
           'block': {'col': (1000, inf)},
           'skipped_records': 80991,
           'left': None,
           'right': None}}


In [44]:
# Fresh environment, reset to obtain its initial observation.
env = QdTreeEnv(EnvContext(env_config, 0))
obs, _ = env.reset() 

# Log-likelihood the restored policy assigns to each of the action ids 0..8
# for that initial observation.
policy = agent.get_policy()
candidate_actions = [*range(9)]
policy.compute_log_likelihoods(candidate_actions, [obs])


tensor([-1.9378, -1.8947, -2.0470, -2.9134, -3.5488, -3.3076, -2.6307, -1.9341,
        -1.4260])