In [2]:
%matplotlib inline

from hydra import initialize, compose
from omegaconf import OmegaConf 

import dr_gen.utils.run as ru
import dr_gen.utils.display as dsp

from dr_gen.analyze.run_group import RunGroup
import dr_gen.analyze.result_plotting as rplt

%load_ext autoreload
%autoreload 2

### Setup Config and Generator

In [3]:
# Compose the Hydra config directly inside the notebook, overriding the
# `paths` config group with its "mac" option.
with initialize(version_base=None, config_path="../configs/"):
    cfg = compose(config_name="config.yaml", overrides=["paths=mac"])

In [4]:
# Seed through the project's helper using the configured seed; the returned
# object is kept as `generator` (presumably an RNG handle — confirm in ru).
generator = ru.set_deterministic(cfg.seed)

In [5]:
# Dump the `paths` sub-config as YAML. resolve=True expands interpolations
# before conversion, so the printed values are concrete.
# NOTE(review): the recorded output contains absolute local paths — consider
# clearing it before sharing.
print(OmegaConf.to_yaml(OmegaConf.to_container(cfg.paths, resolve=True)))

root: /Users/daniellerothermel/drotherm
proj_dir_name: dr_gen
data: /Users/daniellerothermel/drotherm/data
logs: /Users/daniellerothermel/drotherm/logs
my_data: /Users/daniellerothermel/drotherm/data/dr_gen
my_logs: /Users/daniellerothermel/drotherm/logs/dr_gen
run_dir: /Users/daniellerothermel/drotherm/logs/dr_gen/bs500/lr0.1/wd0.0001/s0/2025-03-27/13-04-1743095078
dataset_cache_root: /Users/daniellerothermel/drotherm/data/cifar10/
agg_results: /Users/daniellerothermel/drotherm/data/dr_gen/cifar10/cluster_runs/lr_wd_init_v0



### Load, Dissect and Filter Sweep

In [6]:
# Build a RunGroup and load runs from the directory at cfg.paths.agg_results.
rg = RunGroup()
rg.load_runs_from_base_dir(cfg.paths.agg_results)

>> 0 / 1288 files failed parsing
>> Updated hpm sweep info


In [7]:
# Ask the group to ignore runs matching epochs=180.
# NOTE(review): the return value is discarded, so this presumably mutates
# `rg` in place — confirm that re-running the cell is harmless.
rg.ignore_runs_by_hpms(epochs=180)

>> Ignoring rid: 1287
>> Updated hpm sweep info


In [8]:
# Render the swept-hyperparameter data as a text table (per the recorded
# output: one row group per key, listing the values it takes).
print(dsp.make_table(*rg.get_swept_table_data()))

+------+------------+
| Key  |   Values   |
+------+------------+
| Init | pretrained |
|      |   random   |
+------+------------+
|  WD  |  6.3e-05   |
|      |  0.00025   |
|      |   1e-05    |
|      |  0.00016   |
|      |   0.0001   |
|      |   4e-05    |
+------+------------+
|  LR  |    0.04    |
|      |    0.01    |
|      |    0.2     |
|      |    0.1     |
|      |    0.06    |
|      |    0.16    |
|      |    0.25    |
+------+------------+


In [9]:
# Build the per-combination sweep table, then print it sorted by
# Init, LR and WD.
table = dsp.make_table(*rg.get_hpms_sweep_table())
print(">> Current Sweep, Ready to Analyze:")
dsp.print_table(
    table,
    drop_cols=[],
    sort_cols=['Init', 'LR', 'WD'],
    # presumably a row filter on the LR column — confirm in dsp.print_table.
    # The list omits 0.01 and 0.2, which do appear among the swept LR values.
    lr=[0.04, 0.06, 0.1, 0.16, 0.25],
)

>> Current Sweep, Ready to Analyze:
+------------+------+---------+-------+
|    Init    |  LR  |    WD   | Count |
+------------+------+---------+-------+
| pretrained | 0.04 |  0.0001 |   20  |
| pretrained | 0.04 | 0.00016 |   20  |
| pretrained | 0.04 | 0.00025 |   20  |
| pretrained | 0.04 |  4e-05  |   20  |
| pretrained | 0.04 | 6.3e-05 |   20  |
| pretrained | 0.06 |  0.0001 |   20  |
| pretrained | 0.06 | 0.00016 |   20  |
| pretrained | 0.06 | 0.00025 |   20  |
| pretrained | 0.06 |  4e-05  |   20  |
| pretrained | 0.06 | 6.3e-05 |   20  |
| pretrained | 0.1  |  0.0001 |  103  |
| pretrained | 0.1  | 0.00016 |   20  |
| pretrained | 0.1  | 0.00025 |   20  |
| pretrained | 0.1  |  4e-05  |   20  |
| pretrained | 0.1  | 6.3e-05 |   20  |
| pretrained | 0.16 |  0.0001 |   20  |
| pretrained | 0.16 | 0.00016 |   20  |
| pretrained | 0.16 | 0.00025 |   20  |
| pretrained | 0.16 |  4e-05  |   20  |
| pretrained | 0.16 | 6.3e-05 |   20  |
| pretrained | 0.25 |  0.0001 |   20  |
| pr

In [10]:
# Pick the pretrained-init runs at lr=0.1, wd=1e-4 and report how many run
# ids fall under each matching hyperparameter combination.
runs_pre = rg.select_run_data_by_hpms(lr=0.1, wd=1e-4, init="pretrained")
for hpm, rlist in runs_pre.items():
    print(f" - {hpm!s:70} | {len(rlist):,} RIDS")

 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.0001           | 103 RIDS


In [12]:
# Same selection for random init, this time spelling the hyperparameters with
# their full dotted keys (which cannot be written as plain keyword arguments).
runs_rand = rg.select_run_data_by_hpms(
    **{
        "optim.lr": 0.1,
        "optim.weight_decay": 1e-4,
        "init": "random",
    }
)
for hpm, rlist in runs_rand.items():
    print(f" - {hpm!s:70} | {len(rlist):,} RIDS")

 - model.weights=None optim.lr=0.1 optim.weight_decay=0.0001              | 99 RIDS


## Test Result Plotting

In [None]:
# TODO

In [15]:
# Quick inspection: print each selected hpm key and the first element of its
# run list (per the recorded output, a RunData object).
for hpm, rlist in runs_rand.items():
    print(hpm)
    print(rlist[0])

model.weights=None optim.lr=0.1 optim.weight_decay=0.0001
<dr_gen.analyze.run_data.RunData object at 0x1225c3f50>


In [17]:
# Build the hpm specs with the helper's defaults (no arguments passed).
hpm_specs_one_each = rplt.make_hpm_specs()

In [21]:
# Split the run group's "val" data into pretrained-init and random-init
# dictionaries for the hpm specs built above.
hpms_pre, hpms_rand = rplt.get_pretrained_vs_random_init_runs(
    rg,
    hpm_specs_one_each,
    "val",
    one_per=True,
)

In [24]:
# Sanity check: for each hpm key, print the number of entries and the length
# of the first entry (presumably runs x recorded steps — confirm).
for hpm, rdata in hpms_pre.items():
    print(hpm, len(rdata), len(rdata[0]))

model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.0001 103 270


In [29]:
# Merge both init dicts, then select from the merged dict using hpms_rand
# (presumably keeps only the keys present in hpms_rand — confirm in rplt).
runs = rplt.select_runs_by_hpms({**hpms_pre, **hpms_rand}, hpms_rand)

In [40]:
# Trim with nmax=12, tmax=22; the recorded check that follows shows 12
# entries of length 22 for the remaining hpm key.
runs_trimmed = rplt.trim_runs_metrics_dict(runs, nmax=12, tmax=22)

In [41]:
# Verify the trim: entry count plus the lengths of the first and last entries.
for hpm, rdata in runs_trimmed.items():
    print(hpm, len(rdata), len(rdata[0]), len(rdata[-1]))

model.weights=None optim.lr=0.1 optim.weight_decay=0.0001 12 22 22


In [49]:
# Convert the trimmed metrics dict into a dict of arrays keyed by hpm
# (the shape printed afterwards is (12, 22)).
runs_ndarrays_dict = rplt.runs_metrics_dict_to_ndarray_dict(runs_trimmed)

In [50]:
# Print the array shape stored under each hpm key.
for hpm, rdata in runs_ndarrays_dict.items():
    print(hpm, rdata.shape)

model.weights=None optim.lr=0.1 optim.weight_decay=0.0001 (12, 22)


In [52]:
# Select the best hpm key and timestep from the per-hpm arrays, without early
# stopping and without bootstrap resampling (num_bootstraps=None).
# Uses `runs_ndarrays_dict`, the name actually defined above; the previous
# `runs_ndarrays` only existed as stale kernel state and fails on a fresh run.
hpm, best_t = rplt.bootstrap_select_hpms(
    runs_ndarrays_dict, early_stopping=False, num_bootstraps=None
)

In [54]:
# Show the selected hpm key and timestep index.
print(hpm, best_t)

model.weights=None optim.lr=0.1 optim.weight_decay=0.0001 21


In [59]:
# Take the first (key, array) pair from the dict to experiment with.
one_hpm, one_rdata = next(iter(runs_ndarrays_dict.items()))

In [62]:
# Column `best_t` across all rows: one value per row of the (12, 22) array.
one_rdata[:, best_t]

array([75.83999634, 73.77999878, 72.58000183, 73.97000122, 77.43000031,
       76.72999573, 75.76999664, 73.79000092, 73.40999603, 75.36999512,
       75.31999969, 75.54000092])

In [72]:
# b=None: per the recorded output, the sample comes back unchanged as a
# single row of shape (1, 12), i.e. no resampling.
one_rdata_notb = rplt.bootstrap_samples(one_rdata[:, best_t], b=None)
one_rdata_notb.shape, one_rdata_notb

((1, 12),
 array([[75.83999634, 73.77999878, 72.58000183, 73.97000122, 77.43000031,
         76.72999573, 75.76999664, 73.79000092, 73.40999603, 75.36999512,
         75.31999969, 75.54000092]]))

In [71]:
# b=1: one resampled row of shape (1, 12); repeated values in the recorded
# output indicate sampling with replacement.
one_rdata_b1 = rplt.bootstrap_samples(one_rdata[:, best_t], b=1)
one_rdata_b1.shape, one_rdata_b1

((1, 12),
 array([[77.43000031, 73.77999878, 77.43000031, 75.36999512, 75.31999969,
         75.31999969, 73.40999603, 73.77999878, 73.77999878, 73.79000092,
         75.36999512, 75.36999512]]))

In [70]:
# b=2: two resampled rows, shape (2, 12).
one_rdata_b2 = rplt.bootstrap_samples(one_rdata[:, best_t], b=2)
one_rdata_b2.shape, one_rdata_b2

((2, 12),
 array([[73.97000122, 73.97000122, 73.97000122, 73.79000092, 75.83999634,
         73.77999878, 75.36999512, 75.36999512, 75.83999634, 75.31999969,
         77.43000031, 73.79000092],
        [73.97000122, 75.54000092, 72.58000183, 73.79000092, 72.58000183,
         75.83999634, 75.83999634, 77.43000031, 76.72999573, 76.72999573,
         75.76999664, 73.40999603]]))

In [73]:
# Repeat of the b=2 draw, overwriting `one_rdata_b2`. The recorded output
# differs from the earlier b=2 cell, so successive calls yield different draws.
one_rdata_b2 = rplt.bootstrap_samples(one_rdata[:, best_t], b=2)
one_rdata_b2.shape, one_rdata_b2

((2, 12),
 array([[73.97000122, 75.76999664, 73.79000092, 75.54000092, 72.58000183,
         75.54000092, 75.83999634, 73.97000122, 76.72999573, 75.36999512,
         75.31999969, 77.43000031],
        [75.54000092, 77.43000031, 75.76999664, 77.43000031, 77.43000031,
         73.97000122, 77.43000031, 77.43000031, 73.40999603, 77.43000031,
         73.97000122, 75.31999969]]))

In [81]:
# Summary statistics over 1000 bootstrap resamples of the best_t column.
# NOTE(review): the `_b2` suffix is misleading here since b=1000; consider
# renaming together with the cell that displays it.
estim_one_rdata_b2 = rplt.bootstrap_summary_stats(one_rdata[:, best_t], b=1000)

In [82]:
# NOTE(review): displaying the whole dict dumps (1000, 12) arrays into the
# output; consider showing only the keys or a small slice instead.
estim_one_rdata_b2

{'dist': {'sorted_vals': array([[73.40999603, 73.40999603, 73.40999603, ..., 75.83999634,
          76.72999573, 77.43000031],
         [72.58000183, 72.58000183, 73.40999603, ..., 75.76999664,
          75.76999664, 75.76999664],
         [73.40999603, 73.79000092, 73.97000122, ..., 75.83999634,
          76.72999573, 77.43000031],
         ...,
         [72.58000183, 72.58000183, 73.40999603, ..., 75.76999664,
          76.72999573, 76.72999573],
         [72.58000183, 73.40999603, 73.77999878, ..., 75.76999664,
          75.83999634, 76.72999573],
         [73.77999878, 73.79000092, 73.97000122, ..., 76.72999573,
          76.72999573, 76.72999573]], shape=(1000, 12)),
  'n': array([12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
         12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
      