In [1]:
%matplotlib inline

from hydra import initialize, compose
from omegaconf import OmegaConf 
import pickle as pkl
import pandas as pd

import dr_gen.utils.run as ru
import dr_gen.utils.display as dsp

from dr_gen.analyze.run_group import RunGroup
import dr_gen.analyze.result_plotting as rplt

%load_ext autoreload
%autoreload 2

### Setup Config and Generator

In [2]:
# Build the Hydra config from ../configs/config.yaml, swapping the `paths`
# group for its `mac` variant so every path resolves on this machine.
with initialize(config_path="../configs/", version_base=None):
    cfg = compose(config_name="config.yaml", overrides=["paths=mac"])

In [3]:
# Seed from cfg.seed via the project helper. The returned `generator` is not
# used by any visible cell (presumably a seeded RNG handle — TODO confirm).
generator = ru.set_deterministic(cfg.seed)

In [4]:
# Print the `paths` config group with interpolations resolved, to check the
# concrete locations (e.g. `agg_results`, which the next section loads from).
print(OmegaConf.to_yaml(OmegaConf.to_container(cfg.paths, resolve=True)))

root: /Users/daniellerothermel/drotherm
proj_dir_name: dr_gen
data: /Users/daniellerothermel/drotherm/data
logs: /Users/daniellerothermel/drotherm/logs
my_data: /Users/daniellerothermel/drotherm/data/dr_gen
my_logs: /Users/daniellerothermel/drotherm/logs/dr_gen
run_dir: /Users/daniellerothermel/drotherm/logs/dr_gen/bs500/lr0.1/wd0.0001/s0/2025-04-10/16-27-1744316850
dataset_cache_root: /Users/daniellerothermel/drotherm/data/cifar10/
agg_results: /Users/daniellerothermel/drotherm/data/dr_gen/cifar10/cluster_runs/lr_wd_init_v0



### Load, Dissect, and Filter the Sweep

In [5]:
# Collect every run found under the aggregated-results directory into one
# RunGroup; the output reports how many files failed to parse.
rg = RunGroup()
rg.load_runs_from_base_dir(cfg.paths.agg_results)

>> 0 / 1288 files failed parsing
>> Updated hpm sweep info


In [6]:
# Exclude runs whose `epochs` hpm is 180 from the group (the output shows one
# rid being ignored and the sweep info being refreshed).
rg.ignore_runs_by_hpms(epochs=180)

>> Ignoring rid: 1287
>> Updated hpm sweep info


In [7]:
# Show every hyperparameter that varies across the loaded runs, with its values.
print(dsp.make_table(*rg.get_swept_table_data()))

+------+------------+
| Key  |   Values   |
+------+------------+
| Init | pretrained |
|      |   random   |
+------+------------+
|  WD  |  0.00025   |
|      |   4e-05    |
|      |   0.0001   |
|      |  6.3e-05   |
|      |   1e-05    |
|      |  0.00016   |
+------+------------+
|  LR  |    0.1     |
|      |    0.25    |
|      |    0.06    |
|      |    0.04    |
|      |    0.16    |
|      |    0.2     |
|      |    0.01    |
+------+------------+


In [8]:
# Count runs per (Init, LR, WD) combination, sorted by those columns.
# The `lr=` argument presumably filters rows to the listed LRs (the printed
# table only contains them). The swept LRs 0.2 and 0.01 from the table above
# are therefore not shown here.
table = dsp.make_table(*rg.get_hpms_sweep_table())
print(">> Current Sweep, Ready to Analyze:")
dsp.print_table(
    table,
    drop_cols=[],
    sort_cols=['Init', 'LR', 'WD'],
    lr=[0.04, 0.06, 0.1, 0.16, 0.25],
)

>> Current Sweep, Ready to Analyze:
+------------+------+---------+-------+
|    Init    |  LR  |    WD   | Count |
+------------+------+---------+-------+
| pretrained | 0.04 |  0.0001 |   20  |
| pretrained | 0.04 | 0.00016 |   20  |
| pretrained | 0.04 | 0.00025 |   20  |
| pretrained | 0.04 |  4e-05  |   20  |
| pretrained | 0.04 | 6.3e-05 |   20  |
| pretrained | 0.06 |  0.0001 |   20  |
| pretrained | 0.06 | 0.00016 |   20  |
| pretrained | 0.06 | 0.00025 |   20  |
| pretrained | 0.06 |  4e-05  |   20  |
| pretrained | 0.06 | 6.3e-05 |   20  |
| pretrained | 0.1  |  0.0001 |  103  |
| pretrained | 0.1  | 0.00016 |   20  |
| pretrained | 0.1  | 0.00025 |   20  |
| pretrained | 0.1  |  4e-05  |   20  |
| pretrained | 0.1  | 6.3e-05 |   20  |
| pretrained | 0.16 |  0.0001 |   20  |
| pretrained | 0.16 | 0.00016 |   20  |
| pretrained | 0.16 | 0.00025 |   20  |
| pretrained | 0.16 |  4e-05  |   20  |
| pretrained | 0.16 | 6.3e-05 |   20  |
| pretrained | 0.25 |  0.0001 |   20  |
| pr

In [9]:
# Pretrained init at lr=0.1, wd=1e-4. The output shows a single matching group
# and its run count.
runs_pre = rg.select_run_data_by_hpms(lr=0.1, wd=1e-4, init="pretrained")
for hpm, rlist in runs_pre.items():
    print(f" - {hpm!s:70} | {len(rlist):,} RIDS")

 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.0001           | 103 RIDS


In [10]:
# Random init at lr=0.1, wd=1e-4. This uses the same `lr`/`wd`/`init` shorthand
# as the pretrained selection above, whose output shows it resolving to
# optim.lr / optim.weight_decay, so the two queries read identically.
runs_rand = rg.select_run_data_by_hpms(lr=0.1, wd=1e-4, init="random")
for hpm, rlist in runs_rand.items():
    print(f" - {str(hpm):70} | {len(rlist):,} RIDS")

 - model.weights=None optim.lr=0.1 optim.weight_decay=0.0001              | 99 RIDS


In [11]:
# Pretrained init at lr=0.04: one group per weight decay, listed in sorted
# order of the hpm string.
runs_pre = rg.select_run_data_by_hpms(lr=0.04, init="pretrained")
for hpm, rlist in sorted((str(h), runs) for h, runs in runs_pre.items()):
    print(f" - {hpm:70} | {len(rlist):,} RIDS")

 - model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=0.0001          | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=0.00016         | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=0.00025         | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=4e-05           | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=6.3e-05         | 20 RIDS


In [12]:
# Random init at lr=0.04: one group per weight decay, listed in sorted order of
# the hpm string. The result is stored as `runs_rand` (not `runs_pre`) so the
# variable name matches the init it actually holds.
runs_rand = rg.select_run_data_by_hpms(lr=0.04, init="random")
for hpm, rlist in sorted((str(h), runs) for h, runs in runs_rand.items()):
    print(f" - {hpm:70} | {len(rlist):,} RIDS")

 - model.weights=None optim.lr=0.04 optim.weight_decay=0.0001             | 20 RIDS
 - model.weights=None optim.lr=0.04 optim.weight_decay=0.00016            | 20 RIDS
 - model.weights=None optim.lr=0.04 optim.weight_decay=0.00025            | 20 RIDS
 - model.weights=None optim.lr=0.04 optim.weight_decay=4e-05              | 20 RIDS
 - model.weights=None optim.lr=0.04 optim.weight_decay=6.3e-05            | 20 RIDS


In [13]:
# Pretrained init at lr=0.1: one group per weight decay, listed in sorted
# order of the hpm string.
runs_pre = rg.select_run_data_by_hpms(lr=0.1, init="pretrained")
for hpm, rlist in sorted((str(h), runs) for h, runs in runs_pre.items()):
    print(f" - {hpm:70} | {len(rlist):,} RIDS")

 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.0001           | 103 RIDS
 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.00016          | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.00025          | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=4e-05            | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=6.3e-05          | 20 RIDS


In [14]:
# Random init at lr=0.1: one group per weight decay, listed in sorted order of
# the hpm string. The result is stored as `runs_rand` (not `runs_pre`) so the
# variable name matches the init it actually holds.
runs_rand = rg.select_run_data_by_hpms(lr=0.1, init="random")
for hpm, rlist in sorted((str(h), runs) for h, runs in runs_rand.items()):
    print(f" - {hpm:70} | {len(rlist):,} RIDS")

 - model.weights=None optim.lr=0.1 optim.weight_decay=0.0001              | 99 RIDS
 - model.weights=None optim.lr=0.1 optim.weight_decay=0.00016             | 20 RIDS
 - model.weights=None optim.lr=0.1 optim.weight_decay=0.00025             | 20 RIDS
 - model.weights=None optim.lr=0.1 optim.weight_decay=1e-05               | 20 RIDS
 - model.weights=None optim.lr=0.1 optim.weight_decay=4e-05               | 20 RIDS
 - model.weights=None optim.lr=0.1 optim.weight_decay=6.3e-05             | 20 RIDS


## Test Result Plotting

### Test Without HPM Select

In [15]:
# Hpm specs built with no lr/wd lists (the helper's defaults). Used for the
# comparison without hyperparameter selection below.
hpm_specs_one_each = rplt.make_hpm_specs()

In [16]:
# Compare pretrained vs random init WITHOUT hyperparameter selection, on the
# val split with 1000 bootstraps. The summary is printed by the next cell.
# NOTE(review): the positional 260 and 80 look like the epoch horizon and the
# number of runs per group (the printed summary reports runs shape (80, 260)).
# Confirm this against the rplt signature and name them if so.
compare_stats_one_each = rplt.one_tn_no_hpm_select_compare_weight_init(
    rg, hpm_specs_one_each, 260, 80, num_bootstraps=1000, split="val",
)

In [17]:
# Print per-distribution and bootstrap summary stats for the no-selection comparison.
rplt.print_comparative_summary_stats(compare_stats_one_each)

>> :: Per Dist Summary Stats ::

Compare using 1000 bootstraps:
  - [(80, 260) | best: 259] model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.0001
  - [(80, 260) | best: 259] model.weights=None optim.lr=0.1 optim.weight_decay=0.0001

point
   n          | 80.00000000 | 80.00000000
   mean       | 85.10608847 | 84.42351505
   median     | 85.21789306 | 84.39847893
   min        | 82.00630962 | 82.95413746
   max        | 87.70262908 | 85.43163779
   variance   | 2.08910490 | 0.21388980
   std        | 1.44244867 | 0.45950765
   sem        | 0.16127066 | 0.05137452
   2.5th      | 82.29602312 | 83.58533237
   25th       | 84.08384090 | 84.15027920
   75th       | 86.23335887 | 84.73942000
   97.5th     | 87.33398071 | 85.28615248
   IQR        | 2.14951797 | 0.58914080
std
   n          | 0.00000000 | 0.00000000
   mean       | 0.16398235 | 0.05297842
   median     | 0.26518330 | 0.05869820
   min        | 0.16344764 | 0.51306682
   max        | 0.23278091 | 0.08185878
   variance 

### Test with HPM Select

In [18]:
# Hpm grid for the selection-based comparison: 5 LRs x 6 weight decays at
# 270 epochs. The printed dict shows lr/wd keyed as optim.lr/optim.weight_decay.
hpm_specs_hpm_select = rplt.make_hpm_specs(
    lr=[0.04, 0.06, 0.1, 0.16, 0.25],
    wd=[1e-05, 4e-05, 6.3e-05, 0.0001, 0.00016, 0.00025],
    epochs=270,
)
print(hpm_specs_hpm_select)

{'optim.lr': [0.04, 0.06, 0.1, 0.16, 0.25], 'optim.weight_decay': [1e-05, 4e-05, 6.3e-05, 0.0001, 0.00016, 0.00025], 'epochs': 270}


In [19]:
# Compare pretrained vs random init WITH hyperparameter selection over the grid
# above, using 1000 bootstraps. The result includes per-hpm metadata
# ('metadata_a' / 'metadata_b'), which the next cell inspects.
# NOTE(review): 270 / 100 are positional; presumably the epoch horizon and the
# runs cap — confirm against the rplt signature.
cshs_with_md = rplt.one_tn_hpm_compare_weight_init(
    rg, hpm_specs_hpm_select, 270, 100, num_bootstraps=1000,
)

In [20]:
# For each random-init hpm ('metadata_b'; keys print as model.weights=None),
# show 'best_t' and 'best_t_mean' (presumably the best epoch and the mean
# metric there), sorted by hpm string.
md_b_ex = list(cshs_with_md['metadata_b'].items())
for hpm, hpm_md in sorted(md_b_ex, key=lambda entry: str(entry[0])):
    print(f'{hpm!s:80}', hpm_md['best_t'], hpm_md['best_t_mean'])

model.weights=None optim.lr=0.04 optim.weight_decay=0.0001                       243 83.5912271572113
model.weights=None optim.lr=0.04 optim.weight_decay=0.00016                      218 83.82748880157472
model.weights=None optim.lr=0.04 optim.weight_decay=0.00025                      241 84.25949689025879
model.weights=None optim.lr=0.04 optim.weight_decay=4e-05                        193 83.25663178558351
model.weights=None optim.lr=0.04 optim.weight_decay=6.3e-05                      261 83.50766263313295
model.weights=None optim.lr=0.06 optim.weight_decay=0.0001                       235 84.19552797698975
model.weights=None optim.lr=0.06 optim.weight_decay=0.00016                      262 84.71238954658509
model.weights=None optim.lr=0.06 optim.weight_decay=0.00025                      243 85.24682160377502
model.weights=None optim.lr=0.06 optim.weight_decay=4e-05                        231 83.77230150032044
model.weights=None optim.lr=0.06 optim.weight_decay=6.3e-05               

In [21]:
# Print summary stats for the best pretrained vs best random hpm found by selection.
rplt.print_comparative_summary_stats(cshs_with_md)

>> :: Per Dist Summary Stats ::

Compare using 1000 bootstraps:
  - [(20, 270) | best: 232] model.weights=DEFAULT optim.lr=0.06 optim.weight_decay=0.00025
  - [(20, 270) | best: 267] model.weights=None optim.lr=0.25 optim.weight_decay=0.00025

point
   n          | 20.00000000 | 20.00000000
   mean       | 86.96023288 | 86.03562231
   median     | 86.99662774 | 86.09379286
   min        | 85.94562947 | 84.83297978
   max        | 87.55638788 | 86.64050728
   variance   | 0.18670629 | 0.24217205
   std        | 0.42271431 | 0.47362362
   sem        | 0.09452179 | 0.10590546
   2.5th      | 86.11155134 | 85.05195906
   25th       | 86.76435237 | 85.83871880
   75th       | 87.24982299 | 86.33519309
   97.5th     | 87.52142765 | 86.61771242
   IQR        | 0.48547062 | 0.49647429
std
   n          | 0.00000000 | 0.00000000
   mean       | 0.09695866 | 0.11160185
   median     | 0.10071940 | 0.12913410
   min        | 0.37036688 | 0.56231124
   max        | 0.07283645 | 0.05155814
   varia

### Old Stuff (superseded re-runs of the comparison above)

In [233]:
# Superseded: same function and arguments as the `cshs_with_md` cell above.
# NOTE(review): this cell's saved output (random best_t 264, mean ~86.008)
# differs from that cell's (267, ~86.036). The cause may be an older kernel
# session or different data, or unseeded bootstrap randomness. Confirm before
# treating either result as reproducible.
compare_stats_hpm_select = rplt.one_tn_hpm_compare_weight_init(
    rg, hpm_specs_hpm_select, 270, 100, num_bootstraps=1000,
)

In [234]:
# Print summary stats for the superseded hparam-selection comparison.
rplt.print_comparative_summary_stats(compare_stats_hpm_select)

>> :: Per Dist Summary Stats ::

Compare using 1000 bootstraps:
  - [(20, 270) | best: 232] model.weights=DEFAULT optim.lr=0.06 optim.weight_decay=0.00025
  - [(20, 270) | best: 264] model.weights=None optim.lr=0.25 optim.weight_decay=0.00025

point
   n          | 20.00000000 | 20.00000000
   mean       | 86.95341741 | 86.00767885
   median     | 86.98694271 | 86.05517912
   min        | 85.96465945 | 84.80876588
   max        | 87.55188786 | 86.60768640
   variance   | 0.18639283 | 0.26309991
   std        | 0.42105925 | 0.50213843
   sem        | 0.09415171 | 0.11228157
   2.5th      | 86.11560463 | 84.98695334
   25th       | 86.75398495 | 85.76950394
   75th       | 87.24069806 | 86.39156330
   97.5th     | 87.51628165 | 86.59580253
   IQR        | 0.48671311 | 0.62205936
std
   n          | 0.00000000 | 0.00000000
   mean       | 0.09434803 | 0.11430604
   median     | 0.09243138 | 0.14226848
   min        | 0.37862054 | 0.42244576
   max        | 0.07644215 | 0.02316272
   varia

In [233]:
# Exact duplicate of the earlier In [233] cell. Re-running it only recomputes
# compare_stats_hpm_select. Candidate for deletion.
compare_stats_hpm_select = rplt.one_tn_hpm_compare_weight_init(
    rg, hpm_specs_hpm_select, 270, 100, num_bootstraps=1000,
)

### Get Run Metrics for All Runs and Convert to CSV

In [112]:
# Export locations (pickle + CSVs), all on the current user's Desktop.
# They are built from the home directory rather than a hard-coded /Users/<name>
# prefix, so the notebook works on other machines; on the original machine the
# paths are unchanged.
import os

_DESKTOP = os.path.join(os.path.expanduser("~"), "Desktop")
compare_data_path = os.path.join(_DESKTOP, 'hpm_select_bootstrap_compare_data_t270_n100_v0.pkl')
original_runs_csv_path = os.path.join(_DESKTOP, 't270_n100_runs_data.csv')
bootstrap_samples_csv_path = os.path.join(_DESKTOP, 't270_n100_runs_data__at__t270.csv')

In [16]:
# Same grid as the hpm_specs_hpm_select cell under "Test with HPM Select"
# (identical arguments and printed output). Re-defining it here lets this
# section run without that earlier cell.
hpm_specs_hpm_select = rplt.make_hpm_specs(
    lr=[0.04, 0.06, 0.1, 0.16, 0.25],
    wd=[1e-05, 4e-05, 6.3e-05, 0.0001, 0.00016, 0.00025],
    epochs=270,
)
print(hpm_specs_hpm_select)

{'optim.lr': [0.04, 0.06, 0.1, 0.16, 0.25], 'optim.weight_decay': [1e-05, 4e-05, 6.3e-05, 0.0001, 0.00016, 0.00025], 'epochs': 270}


In [46]:
# Validation-split runs per hpm for both inits (one_per=False — presumably all
# runs per hpm rather than one; TODO confirm), merged into a single dict keyed
# by each hpm's string form.
hpms_val_pre, hpms_val_rand = rplt.get_pretrained_vs_random_init_runs(rg, hpm_specs_hpm_select, "val", one_per=False)
all_hpms = {str(hpm): data for hpm, data in {**hpms_val_pre, **hpms_val_rand}.items()}
# Merging with ** and then keying by str(hpm) would silently drop a group if
# two hpms compared or printed equal. Make sure every group survived.
assert len(all_hpms) == len(hpms_val_pre) + len(hpms_val_rand), "hpm keys collided while merging pretrained/random runs"
list_vals = False  # set True to list every group with its run count
if list_vals:
    for hpm, rlist in sorted(all_hpms.items()):
        print(f" - {hpm:70} | {len(rlist):,} RIDS")

In [116]:
# Peek at the first (Hpm, runs) pair. The output shows runs as a list of
# per-run lists of floats (presumably one value per epoch — TODO confirm).
next(iter(hpms_val_pre.items()))

(<dr_gen.analyze.run_data.Hpm at 0x138eeb830>,
 [[69.43000030517578,
   75.19999694824219,
   79.12999725341797,
   79.22999572753906,
   76.52999877929688,
   79.72000122070312,
   80.29999542236328,
   80.20999908447266,
   82.1500015258789,
   81.08000183105469,
   82.79000091552734,
   83.25999450683594,
   82.04000091552734,
   81.62999725341797,
   83.22000122070312,
   82.97999572753906,
   83.23999786376953,
   84.11000061035156,
   81.93999481201172,
   83.62999725341797,
   83.91999816894531,
   84.61000061035156,
   83.29999542236328,
   83.88999938964844,
   83.08000183105469,
   84.7699966430664,
   84.79999542236328,
   83.23999786376953,
   83.0199966430664,
   83.29000091552734,
   83.83000183105469,
   84.55999755859375,
   83.72000122070312,
   83.63999938964844,
   84.22999572753906,
   83.80999755859375,
   84.50999450683594,
   83.1500015258789,
   83.58000183105469,
   84.22000122070312,
   84.5199966430664,
   84.45999908447266,
   83.8499984741211,
   84.0499954

In [76]:
# Convert each group's list of runs into an array, keeping the string keys
# (per the helper's name — confirm output type).
all_hpms_np = rplt.runs_metrics_dict_to_ndarray_dict(all_hpms)

In [77]:
# Long-format frame: one row per (group, run_index, step_index) with the
# metric value and hpm columns (see the display below). 270 matches epochs=270
# in hpm_specs_hpm_select (presumably the number of steps kept — confirm).
all_hpms_df = rplt.run_groups_to_df(all_hpms_np, 270)

In [78]:
# Display the long-format run metrics.
# TODO(review): the section title promises a CSV export, but no visible cell
# writes all_hpms_df to original_runs_csv_path.
all_hpms_df

Unnamed: 0,group_name,run_index,step_index,metric_value,model.weights,optim.lr,optim.weight_decay
0,model.weights=DEFAULT optim.lr=0.04 optim.weig...,0,0,69.430000,DEFAULT,0.04,0.000063
1,model.weights=DEFAULT optim.lr=0.04 optim.weig...,0,1,75.199997,DEFAULT,0.04,0.000063
2,model.weights=DEFAULT optim.lr=0.04 optim.weig...,0,2,79.129997,DEFAULT,0.04,0.000063
3,model.weights=DEFAULT optim.lr=0.04 optim.weig...,0,3,79.229996,DEFAULT,0.04,0.000063
4,model.weights=DEFAULT optim.lr=0.04 optim.weig...,0,4,76.529999,DEFAULT,0.04,0.000063
...,...,...,...,...,...,...,...
319135,model.weights=None optim.lr=0.25 optim.weight_...,19,265,84.430000,NONE,0.25,0.000040
319136,model.weights=None optim.lr=0.25 optim.weight_...,19,266,84.419998,NONE,0.25,0.000040
319137,model.weights=None optim.lr=0.25 optim.weight_...,19,267,84.419998,NONE,0.25,0.000040
319138,model.weights=None optim.lr=0.25 optim.weight_...,19,268,84.430000,NONE,0.25,0.000040


### Get Bootstrapped Results and Convert to CSV

In [108]:
# Draw b=1000 bootstrap samples per hpm group at select_t=270. The helper
# reports min_num_runs=20, so each sample presumably uses 20 runs (confirm).
bootstrap_samples = rplt.bootstrap_samples_dict(all_hpms, b=1000, select_t=270)

min_num_runs=20


In [109]:
# Flatten the bootstrap samples into the same long format. Each array is
# transposed first; afterwards run_index spans 0-19 and step_index spans the
# bootstraps 0-999 (see the check below). The 1000 here must match b=1000 in
# the bootstrap cell.
bootstrap_samples_df = rplt.run_groups_to_df(
    {k: v.transpose() for k, v in bootstrap_samples.items()},
    1000,
)

In [111]:
# Sanity check the index ranges: expect max run_index 19 and max step_index 999.
bootstrap_samples_df['run_index'].max(), bootstrap_samples_df['step_index'].max()

(np.int64(19), np.int64(999))

In [113]:
# Write the bootstrap samples to CSV (pandas also writes the index as a column by default).
bootstrap_samples_df.to_csv(bootstrap_samples_csv_path)

### Get Summary Stats for Bootstrap Data

In [53]:
# Inspect the top-level keys of the hparam-selection comparison result.
cshs_with_md.keys()

dict_keys(['metadata_a', 'metadata_b', 'best_hpm_a', 'best_t_a', 'runs_shape_a', 'best_hpm_b', 'best_t_b', 'runs_shape_b', 'summary_stats_a', 'summary_stats_b', 'comparative_stats'])

In [56]:
# NOTE(review): `ssa` is not used by any visible cell; this is scratch inspection.
ssa = list(cshs_with_md['summary_stats_a'].items())

In [58]:
# Keys of the pretrained ('a') summary stats: point estimates, bootstrap dists, CIs, etc.
cshs_with_md['summary_stats_a'].keys()

dict_keys(['num_bootstraps', 'bootstrap_dists', 'point', 'std', 'sem', 'ci_95', 'original_data', 'original_data_stats'])

In [59]:
# Statistics that have a bootstrap distribution (plus the sorted sample values).
cshs_with_md['summary_stats_a']['bootstrap_dists'].keys()

dict_keys(['sorted_vals', 'n', 'mean', 'median', 'min', 'max', 'variance', 'std', 'sem', '2.5th', '25th', '75th', '97.5th', 'IQR'])

In [65]:
# Duplicate of the In [59] cell above (same expression, same output); candidate for deletion.
cshs_with_md['summary_stats_a']['bootstrap_dists'].keys()

dict_keys(['sorted_vals', 'n', 'mean', 'median', 'min', 'max', 'variance', 'std', 'sem', '2.5th', '25th', '75th', '97.5th', 'IQR'])

In [60]:
# Point estimates of each statistic for the pretrained ('a') best hpm.
cshs_with_md['summary_stats_a']['point']

{'n': 20.0,
 'mean': 86.96023288154602,
 'median': 86.99662774276733,
 'min': 85.9456294708252,
 'max': 87.55638787841796,
 'variance': 0.18670629216167925,
 'std': 0.4227143132308622,
 'sem': 0.09452179394463466,
 '2.5th': 86.1115513355255,
 '25th': 86.76435236740112,
 '75th': 87.24982299232482,
 '97.5th': 87.52142765083313,
 'IQR': 0.4854706249237061}

In [61]:
# (low, high) 95% bootstrap interval for each statistic of the pretrained ('a') best hpm.
cshs_with_md['summary_stats_a']['ci_95']

{'n': (20.0, 20.0),
 'mean': (86.7654858493805, 87.14357284545899),
 'median': (86.82462196350097, 87.27512273788452),
 'min': (85.68000030517578, 86.63999938964844),
 'max': (87.31999969482422, 87.5999984741211),
 'variance': (0.06753137466327308, 0.3498306366259778),
 'std': (0.2598679836143949, 0.5914647993309764),
 'sem': (0.05810824765375884, 0.13225554976023354),
 '2.5th': (85.68000030517578, 86.64949779510498),
 '25th': (86.56881217956543, 86.97999572753906),
 '75th': (86.98999786376953, 87.43749618530273),
 '97.5th': (87.31999969482422, 87.5999984741211),
 'IQR': (0.222503662109375, 0.7301888942718504)}

In [23]:
# Duplicate of the In [53] cell above (same expression, same output); candidate for deletion.
cshs_with_md.keys()

dict_keys(['metadata_a', 'metadata_b', 'best_hpm_a', 'best_t_a', 'runs_shape_a', 'best_hpm_b', 'best_t_b', 'runs_shape_b', 'summary_stats_a', 'summary_stats_b', 'comparative_stats'])

In [27]:
# Print each hpm group's run-array shape (runs x steps).
# `rpl` was never defined in this notebook (it leaked from a deleted cell), so
# this cell failed on a fresh kernel. Its saved output shows the same string
# keys, in the same first-group order, as `all_hpms_np` from the
# "Get Run Metrics" section, so that dict is used instead.
# NOTE(review): confirm `rpl` was indeed this dict.
for group_name, runs_arr in all_hpms_np.items():
    print(group_name, runs_arr.shape)

model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=6.3e-05 (20, 270)
model.weights=DEFAULT optim.lr=0.25 optim.weight_decay=0.00016 (20, 270)
model.weights=DEFAULT optim.lr=0.16 optim.weight_decay=6.3e-05 (20, 270)
model.weights=DEFAULT optim.lr=0.06 optim.weight_decay=0.00025 (20, 270)
model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=4e-05 (20, 270)
model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.0001 (103, 270)
model.weights=DEFAULT optim.lr=0.06 optim.weight_decay=0.00016 (20, 270)
model.weights=DEFAULT optim.lr=0.06 optim.weight_decay=6.3e-05 (20, 270)
model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.00016 (20, 270)
model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=0.00025 (20, 270)
model.weights=DEFAULT optim.lr=0.25 optim.weight_decay=4e-05 (20, 270)
model.weights=DEFAULT optim.lr=0.25 optim.weight_decay=0.0001 (20, 270)
model.weights=DEFAULT optim.lr=0.16 optim.weight_decay=0.00016 (20, 270)
model.weights=DEFAULT optim.lr=0.06 optim.weight_decay=4e-0