In [1]:
import json
import os
import pandas as pd
import ray

from pprint import pprint 

from env import QdTreeEnv
from qdtree import Workload
from qdtree.schema import ensure_data_schema
from policy import QdTreePolicy

from ray.rllib.env.env_context import EnvContext
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray import air, tune


# Parameters

In [2]:
# Root folder of the TPC-H inputs. This notebook runs at scale factor 1,
# whereas the paper uses scale factor 1000.
BASE_PATH = "../data/tpc-h/"

# Fraction of rows kept for the working sample (the paper samples at 0.1).
SAMPLING_RATE = 0.01

# Minimum block (leaf) size. The paper uses 100K; 100 is the value scaled
# down for our scale factor, and it is further scaled by the sampling rate.
MIN_BLOCK_SIZE = SAMPLING_RATE * 100

# Load the workload

In [3]:
# Location of the workload definition, stored as JSON next to the data.
WORKLOAD_PATH = os.path.join(BASE_PATH, "workload.json")

# Parse the JSON document first, then wrap the parsed content in a Workload.
with open(WORKLOAD_PATH, "r") as workload_file:
    workload_spec = json.load(workload_file)
workload = Workload(workload_spec)

# Show the queries held by the workload as the cell output.
workload._queries

{'1.1.sql': l_shipdate <= 903247200,
 '1.5.sql': l_shipdate <= 905666400,
 '1.3.sql': l_shipdate <= 907048800,
 '1.2.sql': l_shipdate <= 902556000,
 '1.4.sql': l_shipdate <= 906357600,
 '3.1.sql':  and(o_orderdate < 796287600, l_shipdate > 796287600),
 '3.4.sql':  and(o_orderdate < 795250800, l_shipdate > 795250800),
 '3.2.sql':  and(o_orderdate < 795078000, l_shipdate > 795078000),
 '3.3.sql':  and(o_orderdate < 796460400, l_shipdate > 796460400),
 '3.5.sql':  and(o_orderdate < 796633200, l_shipdate > 796633200),
 '4.3.sql':  and(o_orderdate >= 865144800, o_orderdate < 873093600),
 '4.2.sql':  and(o_orderdate >= 783673200, o_orderdate < 791622000),
 '4.1.sql':  and(o_orderdate >= 854780400, o_orderdate < 862466400),
 '4.5.sql':  and(o_orderdate >= 875685600, o_orderdate < 883638000),
 '4.4.sql':  and(o_orderdate >= 794041200, o_orderdate < 801986400),
 '5.4.sql':  and(o_orderdate >= 725871600, o_orderdate < 757407600),
 '5.2.sql':  and(o_orderdate >= 725871600, o_orderdate < 757407600

# Load and sample data

In [4]:
# Denormalized TPC-H table at scale factor 1.
DATA_PATH = os.path.join(BASE_PATH, "sf1/denormalized.parquet")

# Fraction of the full table kept immediately after loading.
PRELOAD_SAMPLE_FRAC = 0.1
# Fixed seed so that "Restart Kernel -> Run All" always selects the same rows.
PRELOAD_SAMPLE_SEED = 42

all_data = pd.read_parquet(DATA_PATH).sample(
    frac=PRELOAD_SAMPLE_FRAC, random_state=PRELOAD_SAMPLE_SEED
)
# Lower-case the column names.
all_data.columns = all_data.columns.str.lower()
# Pass the frame through ensure_data_schema together with the workload's schema.
all_data = ensure_data_schema(all_data, workload.schema)
all_data.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 600122 entries, 269621 to 5769231
Data columns (total 35 columns):
 #   Column            Non-Null Count   Dtype  
---  ------            --------------   -----  
 0   l_orderkey        600122 non-null  int64  
 1   l_partkey         600122 non-null  int64  
 2   l_suppkey         600122 non-null  int64  
 3   l_linenumber      600122 non-null  int64  
 4   l_quantity        600122 non-null  float64
 5   l_extendedprice   600122 non-null  float64
 6   l_discount        600122 non-null  float64
 7   l_tax             600122 non-null  float64
 8   l_shipdate        600122 non-null  int64  
 9   l_commitdate      600122 non-null  int64  
 10  l_receiptdate     600122 non-null  int64  
 11  o_orderkey        600122 non-null  int64  
 12  o_custkey         600122 non-null  int64  
 13  o_totalprice      600122 non-null  float64
 14  o_orderdate       600122 non-null  int64  
 15  o_shippriority    600122 non-null  int64  
 16  c_custkey     

In [5]:
# Draw the working sample used to build the qd-tree. The seed is fixed so the
# same rows (and therefore the same environment) are produced on every run.
data = all_data.sample(frac=SAMPLING_RATE, random_state=42)
len(data)

6001

# Set up the environment

In [6]:
# Start (or reuse) a Ray instance for this notebook, with num_cpus set to 8.
# ignore_reinit_error=True lets the cell be re-run without raising.
ray.init(
    local_mode=False,
    ignore_reinit_error=True,
    num_cpus=8,
)

2024-05-19 19:24:24,948	INFO worker.py:1553 -- Started a local Ray instance.


0,1
Python version:,3.10.14
Ray version:,2.3.1


In [7]:
# Settings forwarded to QdTreeEnv: the parsed workload, the sampled rows and
# the minimum leaf size.
env_config = dict(
    workload=workload,
    data=data,
    min_leaf_size=MIN_BLOCK_SIZE,
)

# Model settings for the policy: hidden layer sizes [512, 512], ReLU activation.
model_config = {"fcnet_hiddens": [512, 512], "fcnet_activation": "relu"}

# Assemble the PPO configuration one step at a time.
config = PPOConfig()
config = config.environment(QdTreeEnv, env_config=env_config)
config = config.framework("torch")
config = config.rollouts(num_rollout_workers=2, batch_mode="complete_episodes")
config = config.training(model=model_config)

# NOTE(review): `algo` is not referenced by the cells visible in this notebook
# (the Tuner is given "PPO" plus config.to_dict(), and the evaluation cell
# calls config.build() again) -- confirm whether this build is still needed.
algo = config.build()

# Stopping criteria passed to the Tune run.
stop = dict(
    training_iteration=10,
    timesteps_total=40000,
)

2024-05-19 19:24:33,745	INFO algorithm.py:506 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


[2m[36m(PPO pid=13860)[0m 2024-05-19 19:24:42,604	INFO algorithm.py:506 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.


In [8]:
# Switch the working directory to the project's `src` folder. The location can
# be overridden with the QDTREE_SRC_DIR environment variable so the notebook
# is not tied to one machine; the original path remains the default.
os.chdir(os.environ.get("QDTREE_SRC_DIR", "/Users/umsaka/Documents/DDCPS/qdtree/src"))

In [9]:
# Run PPO through Ray Tune with the configuration built above.
# Results go to a `results` folder under the current working directory
# instead of a hardcoded, machine-specific absolute path. After the preceding
# os.chdir into the project's `src` folder this resolves to `<src>/results`.
tuner = tune.Tuner(
    "PPO",
    param_space=config.to_dict(),
    run_config=air.RunConfig(
        local_dir=os.path.abspath("results"),
        stop=stop,
        # Save a checkpoint when the run finishes so it can be restored later.
        checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True),
    ),
)

result = tuner.fit()

0,1
Current time:,2024-05-19 19:26:15
Running for:,00:01:35.31
Memory:,15.2/24.0 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
PPO_QdTreeEnv_bad10_00000,TERMINATED,127.0.0.1:13860,10,90.2455,40270,0.303888,0.48514,0.129307,15.6641


Trial name,agent_timesteps_total,connector_metrics,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,experiment_id,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_trained,num_env_steps_trained_this_iter,num_faulty_episodes,num_healthy_workers,num_in_flight_async_reqs,num_remote_worker_restarts,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
PPO_QdTreeEnv_bad10_00000,40270,"{'ObsPreprocessorConnector_ms': 0.0023609958589076996, 'StateBufferConnector_ms': 0.002108979970216751, 'ViewRequirementAgentConnector_ms': 0.04327511414885521}","{'num_env_steps_sampled': 40270, 'num_env_steps_trained': 40270, 'num_agent_steps_sampled': 40270, 'num_agent_steps_trained': 40270}",{},2024-05-19_19-26-15,True,15.6641,{},0.48514,0.303888,0.129307,256,2398,dd58b9a45eae4e868e4fdda8b40f4a7f,Mete-New.local,"{'learner': {'default_policy': {'learner_stats': {'allreduce_latency': 0.0, 'grad_gnorm': 1.860827862831854, 'cur_kl_coeff': 0.6750000000000002, 'cur_lr': 5.0000000000000016e-05, 'total_loss': -0.06883565503553117, 'policy_loss': -0.09064365520983213, 'vf_loss': 0.00855401635360253, 'vf_explained_var': 0.033846884132713397, 'kl': 0.01963553222075952, 'entropy': 4.348961020541448, 'entropy_coeff': 0.0}, 'model': {}, 'custom_metrics': {}, 'num_agent_steps_trained': 128.0, 'num_grad_updates_lifetime': 8835.5, 'diff_num_grad_updates_vs_sampler_policy': 464.5}}, 'num_env_steps_sampled': 40270, 'num_env_steps_trained': 40270, 'num_agent_steps_sampled': 40270, 'num_agent_steps_trained': 40270}",10,127.0.0.1,40270,40270,40270,4010,40270,4010,0,2,0,0,4010,"{'cpu_util_percent': 21.40833333333333, 'ram_util_percent': 63.116666666666674}",13860,{},{},{},"{'mean_raw_obs_processing_ms': 0.15513650456068706, 'mean_inference_ms': 0.41910595969507736, 'mean_action_processing_ms': 0.04144881000508645, 'mean_env_wait_ms': 0.5871797523901627, 'mean_env_render_ms': 0.0}","{'episode_reward_max': 0.4851396562111443, 'episode_reward_min': 0.12930665401920194, 'episode_reward_mean': 0.30388759325368847, 'episode_len_mean': 15.6640625, 'episode_media': {}, 'episodes_this_iter': 256, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [0.33354440926512247, 0.3941625113763091, 0.3151936215758912, 0.19987694358632535, 0.2379065027623601, 0.36788996705677257, 
0.14273518516144745, 0.2645353979464961, 0.3207593606193839, 0.32299744914309153, 0.2514939817722687, 0.19272941689205644, 0.4033609782984887, 0.20373271121479752, 0.31115839667747686, 0.3481829951418353, 0.37535282581108276, 0.3612244113160627, 0.2789176162947201, 0.37803955750964585, 0.2749541743042826, 0.31905195287964827, 0.29305115814030996, 0.2293899734659608, 0.24696909489444066, 0.3389435094150975, 0.17994180457103304, 0.17883942419853102, 0.2512068501403612, 0.2910540550933819, 0.29486880391729586, 0.38519990258033915, 0.3800366605565739, 0.14326586594541935, 0.3307269301270301, 0.2916437004089062, 0.2970889467140092, 0.16223449937830875, 0.4572699421891223, 0.38296950508248623, 0.3461730737184828, 0.39161678181841486, 0.18558189019778754, 0.36458026226398166, 0.32972709676592366, 0.3880840372758386, 0.27835360773204465, 0.16843859356773871, 0.28068142489072334, 0.33115506389960647, 0.38133388025072745, 0.29521746375604063, 0.3153269326907054, 0.3497417097150475, 0.43988822375757886, 0.24432338200043582, 0.3663697076128338, 0.32419212182584956, 0.4587543101790727, 0.3718072628920821, 0.27788445515491006, 0.3433632856062451, 0.2664145719303193, 0.25692640970094727, 0.3157422480868573, 0.3090895107225719, 0.319403176393678, 0.357381462064015, 0.3509902195787882, 0.2674374783689898, 0.2623511466037712, 0.34784715367951496, 0.22739030674374783, 0.2594747029341264, 0.48313229846307665, 0.25523182033763603, 0.4442208349890403, 0.3086998320792688, 0.3231102508556266, 0.33083460448899543, 0.3831181982490098, 0.25684437209183086, 0.4049530206504044, 0.13846153846153847, 0.361193647212644, 0.2650763334316076, 0.25777498622024536, 0.2995039288323741, 0.377690897670901, 0.42493430582082475, 0.30431851101739454, 0.3699639803622473, 0.3696537756527758, 0.18715598682270904, 0.2678245933370079, 0.3416763872687885, 0.3919295502031713, 0.3904374911873662, 0.2511581403099483, 0.3392870419032725, 0.3400048709830413, 0.2730750003204594, 0.3718995552023381, 0.231223001294656, 
0.3295553305218361, 0.443354312742748, 0.24923538384628202, 0.3194416315229513, 0.4104367220847807, 0.3264763565046851, 0.3516003742965916, 0.39476497506825786, 0.4062861317985464, 0.3358645353979465, 0.37914193788214784, 0.15357184059067078, 0.26760668093779244, 0.32297437606552754, 0.2564085472933998, 0.262056323946009, 0.18761232102341918, 0.3330060374552959, 0.3872405881071104, 0.28661889685052494, 0.3970799738505121, 0.3591657800622973, 0.2902336790022176, 0.3671336828477305, 0.14326586594541935, 0.33806416879238077, 0.18761232102341918, 0.44040608616512633, 0.22897209439452398, 0.2487559765680079, 0.3388819812082602, 0.3467293912553036, 0.3263199723123069, 0.28300667837411714, 0.27419276274467075, 0.37119198082370886, 0.413256764898158, 0.22123748606001564, 0.14890338789688898, 0.24018304641534102, 0.27281863279197055, 0.33497494007409023, 0.257690384935844, 0.4567520797815749, 0.3435068514221989, 0.4034660889851691, 0.18951713176009127, 0.3941240562470357, 0.34939817722687244, 0.25963621447707436, 0.40611692922974374, 0.34974683706561727, 0.32602258597925987, 0.2794200966505582, 0.19784651276069373, 0.12930665401920194, 0.2869547383128453, 0.2650224962506249, 0.39198338738415395, 0.18715598682270904, 0.282734928793919, 0.3829566867060618, 0.27290836142694164, 0.2654455026726315, 0.24870470306231013, 0.3184238524348506, 0.29180777562713905, 0.20277646033353416, 0.30383397638855064, 0.29548152231038416, 0.26910130362888235, 0.2864753310345712, 0.2522528296565957, 0.23196133977670388, 0.2794098419494187, 0.3548024047274172, 0.4851396562111443, 0.26154102521374645, 0.3184802532911181, 0.3170266494045864, 0.2351710612333842, 0.263732967582326, 0.3922935920936254, 0.2374681142886442, 0.36105777242254494, 0.30010895619960776, 0.24696909489444066, 0.34096368553958956, 0.27892787099585964, 0.3453091151474754, 0.3710073962031969, 0.39894632945791086, 0.3239152448950816, 0.2762180662197326, 0.2799200133311115, 0.367536179867458, 0.1926755797110738, 0.16327022419340367, 
0.2232858626126415, 0.240413777190981, 0.243331239665184, 0.3503134093035776, 0.2296335226180252, 0.14490661812774794, 0.4247881763295861, 0.43894991860330973, 0.2981708176842321, 0.25447809980387887, 0.1754810095753272, 0.3593426736569546, 0.28387832797097917, 0.2733570046017971, 0.2909771448348352, 0.16355735582531117, 0.3881019830028329, 0.3003448143258175, 0.2942125030443644, 0.2858985040954713, 0.2929742478817633, 0.22398061861484625, 0.31077384538474356, 0.386743235101842, 0.29639675438708935, 0.19339084511555765, 0.2698191327086511, 0.38686629151551666, 0.35672259751579866, 0.3524335687641803, 0.34437080999320624, 0.30442105802879005, 0.26147180598105446, 0.19966159486239474, 0.21701254919051952, 0.38799687231615243, 0.343411995436658, 0.2863779113737454, 0.4037198928383731, 0.3324702293207542, 0.21858664581544102, 0.32702498301565125, 0.2985322958994014, 0.23665542922333457, 0.246317921372079, 0.391016881801751, 0.30415956314973147, 0.31175573301885584, 0.1708971581659467, 0.2774793944598977, 0.3221668183507877, 0.3603040518887878, 0.279104764590517, 0.2407880737825747], 'episode_lengths': [24, 20, 10, 8, 12, 30, 4, 16, 8, 20, 16, 6, 24, 4, 12, 12, 24, 18, 24, 16, 24, 14, 8, 12, 18, 8, 6, 6, 6, 46, 14, 28, 12, 4, 14, 14, 10, 14, 22, 16, 16, 22, 4, 32, 8, 26, 6, 6, 6, 10, 12, 16, 8, 14, 20, 8, 20, 10, 32, 20, 40, 30, 28, 12, 22, 14, 18, 12, 12, 6, 14, 18, 8, 26, 36, 8, 16, 12, 16, 24, 48, 14, 26, 4, 16, 16, 16, 18, 22, 20, 22, 14, 18, 4, 12, 34, 18, 20, 16, 22, 14, 14, 18, 20, 42, 18, 6, 22, 28, 26, 28, 28, 26, 8, 34, 10, 14, 8, 6, 6, 4, 8, 16, 16, 12, 18, 6, 24, 4, 24, 4, 16, 16, 26, 12, 20, 10, 14, 6, 10, 22, 6, 6, 20, 14, 6, 6, 26, 12, 18, 6, 8, 36, 14, 16, 20, 22, 22, 24, 4, 10, 14, 12, 4, 14, 16, 10, 14, 8, 16, 22, 6, 18, 18, 16, 12, 6, 14, 8, 22, 18, 8, 18, 12, 10, 8, 26, 16, 16, 10, 10, 10, 16, 14, 24, 40, 10, 34, 12, 24, 6, 14, 12, 16, 6, 24, 6, 6, 22, 30, 8, 6, 18, 8, 14, 8, 8, 4, 32, 16, 8, 18, 22, 8, 22, 18, 18, 6, 18, 10, 14, 22, 16, 18, 8, 12, 
18, 20, 12, 14, 24, 12, 8, 18, 16, 6, 14, 22, 22, 14, 4, 24, 16, 20, 14, 12]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 0.15513650456068706, 'mean_inference_ms': 0.41910595969507736, 'mean_action_processing_ms': 0.04144881000508645, 'mean_env_wait_ms': 0.5871797523901627, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0, 'connector_metrics': {'ObsPreprocessorConnector_ms': 0.0023609958589076996, 'StateBufferConnector_ms': 0.002108979970216751, 'ViewRequirementAgentConnector_ms': 0.04327511414885521}}",90.2455,8.85839,90.2455,"{'training_iteration_time_ms': 9020.378, 'load_time_ms': 1.276, 'load_throughput': 3156977.722, 'learn_time_ms': 6542.961, 'learn_throughput': 615.471, 'synch_weights_time_ms': 2.453}",1716168375,0,40270,10,bad10_00000,2.46925


2024-05-19 19:26:15,822	INFO tune.py:798 -- Total run time: 95.44 seconds (95.23 seconds for the tuning loop).


In [10]:
# Rebuild a PPO algorithm from the same config and load the checkpoint
# produced by the tuning run.
agent = config.build()
agent.restore(result.get_best_result().checkpoint) # type: ignore

# Roll out a single episode with exploration disabled in a fresh environment.
env = QdTreeEnv(EnvContext(env_config, 0))
done = False
# Avoid binding `_`: in IPython it holds the previous cell output.
obs, _reset_info = env.reset()
step = 0
episode_reward = 0
while not done:
    action = agent.compute_single_action(obs, explore=False)
    obs, reward, terminated, truncated, info = env.step(action) # type: ignore
    # Stop on either end-of-episode signal. Checking only the first flag would
    # loop forever if the environment ever ends an episode by truncation.
    done = terminated or truncated
    episode_reward += reward
    step += 1

# Show the blocks of the resulting qd-tree and the total episode reward.
pprint(env.qd_tree.blocks)
print(episode_reward)


2024-05-19 19:30:03,694	INFO trainable.py:791 -- Restored on 127.0.0.1 from checkpoint: /Users/umsaka/Documents/DDCPS/qdtree/src/results/PPO/PPO_QdTreeEnv_bad10_00000_0_2024-05-19_19-24-40/checkpoint_000010
2024-05-19 19:30:03,694	INFO trainable.py:800 -- Current state after restoring: {'_iteration': 10, '_timesteps_total': None, '_time_total': 90.24550175666809, '_episodes_total': 2398}


[{'c_acctbal': (-inf, inf),
  'c_custkey': (-inf, inf),
  'c_nationkey': (-inf, inf),
  'l_commitdate': (-inf, inf),
  'l_discount': (-inf, inf),
  'l_extendedprice': (-inf, inf),
  'l_linenumber': (-inf, inf),
  'l_orderkey': (-inf, inf),
  'l_partkey': (-inf, inf),
  'l_quantity': (-inf, inf),
  'l_receiptdate': (-inf, inf),
  'l_shipdate': [788943600, 807256800),
  'l_suppkey': (-inf, inf),
  'l_tax': (-inf, inf),
  'n_nationkey_cust': (-inf, inf),
  'n_nationkey_supp': (-inf, inf),
  'n_regionkey_cust': (-inf, inf),
  'n_regionkey_supp': (-inf, inf),
  'o_custkey': (-inf, inf),
  'o_orderdate': [796460400, inf),
  'o_orderkey': (-inf, inf),
  'o_shippriority': (-inf, inf),
  'o_totalprice': (-inf, inf),
  'p_partkey': (-inf, inf),
  'p_retailprice': (-inf, inf),
  'p_size': (-inf, inf),
  'ps_availqty': (-inf, inf),
  'ps_partkey': (-inf, inf),
  'ps_suppkey': (-inf, inf),
  'ps_supplycost': (-inf, inf),
  'r_regionkey_cust': (-inf, inf),
  'r_regionkey_supp': (-inf, inf),
  's_acc

In [11]:
# Print the string form of the qd-tree held by the evaluation environment.
print(env.qd_tree)

{'id': 1,
 'cut': l_shipdate < 788943600,
 'size': 6001,
 'block': {'l_orderkey': (-inf, inf),
           'l_partkey': (-inf, inf),
           'l_suppkey': (-inf, inf),
           'l_linenumber': (-inf, inf),
           'l_quantity': (-inf, inf),
           'l_extendedprice': (-inf, inf),
           'l_discount': (-inf, inf),
           'l_tax': (-inf, inf),
           'l_shipdate': (-inf, inf),
           'l_commitdate': (-inf, inf),
           'l_receiptdate': (-inf, inf),
           'o_orderkey': (-inf, inf),
           'o_custkey': (-inf, inf),
           'o_totalprice': (-inf, inf),
           'o_orderdate': (-inf, inf),
           'o_shippriority': (-inf, inf),
           'c_custkey': (-inf, inf),
           'c_nationkey': (-inf, inf),
           'c_acctbal': (-inf, inf),
           'n_nationkey_cust': (-inf, inf),
           'n_regionkey_cust': (-inf, inf),
           'r_regionkey_cust': (-inf, inf),
           'ps_partkey': (-inf, inf),
           'ps_suppkey': (-inf, inf),
   