In [4]:
import os
import pandas as pd
from glob import glob
from humemai.utils import read_yaml

# Collect the mean test score and the hyper-parameters of every hp-tuning run.
# Each run directory is expected to hold a checkpoint (*.pt) next to a
# train.yaml (hyper-parameters) and, once evaluated, a results.yaml.
MODEL_GLOB = "./training-results/hp-tuning-s/*/*.pt"
HPARAM_KEYS = [
    "batch_size",
    "warm_start",
    "replay_buffer_size",
    "gamma",
    "target_update_interval",
    "ddqn",
    "dueling_dqn",
]
COLUMNS = ["test_score", *HPARAM_KEYS, "model_path", "dir_name"]

# Glob once and sort: glob() returns paths in arbitrary, filesystem-dependent
# order, which would make the row order / DataFrame index non-reproducible.
model_paths = sorted(glob(MODEL_GLOB))
print(f"Found {len(model_paths)} models")

table = []
for model_path in model_paths:
    dir_name = os.path.dirname(model_path)
    train_path = os.path.join(dir_name, "train.yaml")
    results_path = os.path.join(dir_name, "results.yaml")

    train = read_yaml(train_path)

    try:
        results = read_yaml(results_path)
    except FileNotFoundError:
        # No test results for this run, so it cannot be ranked; leave it out.
        print(f"Skipping {model_path} because results.yaml not found")
        continue

    table.append(
        {
            "test_score": results["test_score"]["mean"],
            **{key: train[key] for key in HPARAM_KEYS},
            "model_path": os.path.basename(model_path),
            # basename instead of split("/") so Windows path separators work too
            "dir_name": os.path.basename(dir_name),
        }
    )

print(f"Found {len(table)} models with okay results")

# Explicit columns keep the frame (and the sort below) valid even when no run
# has results; otherwise sort_values("test_score") raises KeyError.
df = pd.DataFrame(table, columns=COLUMNS)
df = df.sort_values("test_score", ascending=False)
df.head(10)

Found 187 models
Skipping ./training-results/hp-tuning-s/2023-11-14 18:26:11.639586/episode=2_val-score=49.pt because results.yaml not found
Found 186 models with okay results


Unnamed: 0,test_score,batch_size,warm_start,replay_buffer_size,gamma,target_update_interval,ddqn,dueling_dqn,model_path,dir_name
42,97.6,512,4096,8192,0.218268,43,True,True,episode=1_val-score=100.pt,2023-11-13 18:31:39.499406
182,94.8,256,1024,2048,0.992601,99,False,False,episode=4_val-score=95.pt,2023-11-13 21:38:47.405336
164,94.6,256,512,4096,0.650934,19,True,False,episode=1_val-score=96.pt,2023-11-13 16:01:51.568960
68,94.2,512,4096,8192,0.492716,73,True,False,episode=3_val-score=95.pt,2023-11-14 15:21:23.042208
87,94.2,512,1024,2048,0.817004,71,False,True,episode=3_val-score=98.pt,2023-11-13 20:10:57.478853
78,93.8,128,128,512,0.924394,66,False,True,episode=13_val-score=97.pt,2023-11-14 02:47:54.735724
151,93.0,256,2048,8192,0.930493,65,False,True,episode=6_val-score=96.pt,2023-11-14 00:46:48.746845
10,93.0,256,256,2048,0.08151,44,False,True,episode=0_val-score=98.pt,2023-11-14 12:18:53.999387
55,92.4,256,2048,16384,0.507841,49,True,False,episode=8_val-score=96.pt,2023-11-13 16:40:40.392867
83,91.2,128,512,512,0.47838,90,True,False,episode=9_val-score=92.pt,2023-11-14 07:36:39.216014


In [17]:
# Consensus hyper-parameters across all collected runs (not only the top rows):
# the most frequent value for the discrete / boolean settings, and the average
# for the continuous-ish ones. Output is a 7-tuple in this fixed order:
# batch_size, warm_start, replay_buffer_size, gamma, target_update_interval,
# ddqn, dueling_dqn.
(
    *(df[col].mode() for col in ("batch_size", "warm_start", "replay_buffer_size")),
    *(df[col].mean() for col in ("gamma", "target_update_interval")),
    *(df[col].mode() for col in ("ddqn", "dueling_dqn")),
)

(0    128
 Name: batch_size, dtype: int64,
 0    1024
 Name: warm_start, dtype: int64,
 0    4096
 Name: replay_buffer_size, dtype: int64,
 0.5229346023350929,
 67.20967741935483,
 0    False
 Name: ddqn, dtype: bool,
 0    False
 Name: dueling_dqn, dtype: bool)