In [195]:
%matplotlib inline

from hydra import initialize, compose
from omegaconf import OmegaConf 

import dr_gen.utils.run as ru
import dr_gen.utils.display as dsp

from dr_gen.analyze.run_group import RunGroup
import dr_gen.analyze.result_plotting as rplt

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Setup Config and Generator

In [2]:
# Compose the project config from ../configs/config.yaml, selecting the `mac`
# option of the `paths` config group via a Hydra override.
with initialize(version_base=None, config_path="../configs/"):
    cfg = compose(config_name="config.yaml", overrides=["paths=mac"])

In [3]:
# Make the session reproducible from cfg.seed. NOTE(review): which RNGs are
# seeded, and the type of the returned `generator`, are defined in
# dr_gen.utils.run.set_deterministic — confirm there. `generator` is not used
# by any later cell in this notebook.
generator = ru.set_deterministic(cfg.seed)

In [4]:
# Show the fully-resolved `paths` section of the config (interpolations
# expanded) as YAML, to confirm which directories this session will read.
print(
    OmegaConf.to_yaml(
        OmegaConf.to_container(cfg.paths, resolve=True)
    )
)

root: /Users/daniellerothermel/drotherm
proj_dir_name: dr_gen
data: /Users/daniellerothermel/drotherm/data
logs: /Users/daniellerothermel/drotherm/logs
my_data: /Users/daniellerothermel/drotherm/data/dr_gen
my_logs: /Users/daniellerothermel/drotherm/logs/dr_gen
run_dir: /Users/daniellerothermel/drotherm/logs/dr_gen/bs500/lr0.1/wd0.0001/s0/2025-04-03/10-22-1743690172
dataset_cache_root: /Users/daniellerothermel/drotherm/data/cifar10/
agg_results: /Users/daniellerothermel/drotherm/data/dr_gen/cifar10/cluster_runs/lr_wd_init_v0



### Load, Dissect and Filter Sweep

In [5]:
# Create an empty RunGroup, then fill it from the run files found under the
# aggregated-results directory configured at cfg.paths.agg_results.
rg = RunGroup()
rg.load_runs_from_base_dir(
    cfg.paths.agg_results,
)

>> 0 / 1288 files failed parsing
>> Updated hpm sweep info


In [6]:
# Exclude every run whose `epochs` hpm is 180 from the analysis (per the output
# below, a single run: rid 1287). This modifies `rg` in place; nothing is assigned.
rg.ignore_runs_by_hpms(epochs=180)

>> Ignoring rid: 1287
>> Updated hpm sweep info


In [7]:
# List each hyperparameter that varies across the loaded runs, together with
# the distinct values it takes.
print(
    dsp.make_table(
        *rg.get_swept_table_data()
    )
)

+------+------------+
| Key  |   Values   |
+------+------------+
| Init | pretrained |
|      |   random   |
+------+------------+
|  WD  |  6.3e-05   |
|      |   1e-05    |
|      |  0.00016   |
|      |   0.0001   |
|      |  0.00025   |
|      |   4e-05    |
+------+------------+
|  LR  |    0.16    |
|      |    0.04    |
|      |    0.25    |
|      |    0.2     |
|      |    0.06    |
|      |    0.1     |
|      |    0.01    |
+------+------------+


In [8]:
# Count runs per (Init, LR, WD) combination. The `lr=` filter keeps only the
# main LR grid; the swept LRs 0.2 and 0.01 are not shown.
# NOTE(review): in the output, WD looks sorted as text (0.0001 before 4e-05).
# Confirm whether dsp.print_table sorts the formatted strings.
sweep_rows = rg.get_hpms_sweep_table()
table = dsp.make_table(*sweep_rows)
del sweep_rows
print(">> Current Sweep, Ready to Analyze:")
dsp.print_table(
    table,
    drop_cols=[],
    sort_cols=["Init", "LR", "WD"],
    lr=[0.04, 0.06, 0.1, 0.16, 0.25],
)

>> Current Sweep, Ready to Analyze:
+------------+------+---------+-------+
|    Init    |  LR  |    WD   | Count |
+------------+------+---------+-------+
| pretrained | 0.04 |  0.0001 |   20  |
| pretrained | 0.04 | 0.00016 |   20  |
| pretrained | 0.04 | 0.00025 |   20  |
| pretrained | 0.04 |  4e-05  |   20  |
| pretrained | 0.04 | 6.3e-05 |   20  |
| pretrained | 0.06 |  0.0001 |   20  |
| pretrained | 0.06 | 0.00016 |   20  |
| pretrained | 0.06 | 0.00025 |   20  |
| pretrained | 0.06 |  4e-05  |   20  |
| pretrained | 0.06 | 6.3e-05 |   20  |
| pretrained | 0.1  |  0.0001 |  103  |
| pretrained | 0.1  | 0.00016 |   20  |
| pretrained | 0.1  | 0.00025 |   20  |
| pretrained | 0.1  |  4e-05  |   20  |
| pretrained | 0.1  | 6.3e-05 |   20  |
| pretrained | 0.16 |  0.0001 |   20  |
| pretrained | 0.16 | 0.00016 |   20  |
| pretrained | 0.16 | 0.00025 |   20  |
| pretrained | 0.16 |  4e-05  |   20  |
| pretrained | 0.16 | 6.3e-05 |   20  |
| pretrained | 0.25 |  0.0001 |   20  |
| pr

In [9]:
# Pretrained-init runs at the reference config (lr=0.1, wd=1e-4), with the
# number of runs found for each matching hpm key.
runs_pre = rg.select_run_data_by_hpms(
    lr=0.1,
    wd=1e-4,
    init="pretrained",
)
for hpm, rlist in runs_pre.items():
    print(f" - {hpm!s:70} | {len(rlist):,} RIDS")

 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.0001           | 103 RIDS


In [10]:
# Random-init runs at the reference config (lr=0.1, wd=1e-4): the counterpart
# of the pretrained selection above. Uses the same short hpm aliases
# (lr / wd / init) as that cell instead of a **{"optim.lr": ...} dict.
runs_rand = rg.select_run_data_by_hpms(lr=0.1, wd=1e-4, init="random")
for hpm, rlist in runs_rand.items():
    print(f" - {str(hpm):70} | {len(rlist):,} RIDS")

 - model.weights=None optim.lr=0.1 optim.weight_decay=0.0001              | 99 RIDS


In [184]:
# Pretrained-init runs at lr=0.04, one line per weight-decay setting.
# Hpm keys do not support `<` (see the TypeError further down), so sort by
# their string form with `key=`. Unlike sorting (str, list) tuples, this never
# falls back to comparing the run lists when two keys stringify identically.
runs_pre_lr004 = rg.select_run_data_by_hpms(
    lr=0.04, init="pretrained",
)
for hpm, rlist in sorted(runs_pre_lr004.items(), key=lambda item: str(item[0])):
    print(f" - {str(hpm):70} | {len(rlist):,} RIDS")

 - model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=0.0001          | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=0.00016         | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=0.00025         | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=4e-05           | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.04 optim.weight_decay=6.3e-05         | 20 RIDS


In [185]:
# Random-init runs at lr=0.04, one line per weight-decay setting. These are
# named `runs_rand_*` (not `runs_pre`) because they are random-init runs.
# Hpm keys do not support `<`, so sort by their string form with `key=`.
# Unlike sorting (str, list) tuples, this never compares the run lists on ties.
runs_rand_lr004 = rg.select_run_data_by_hpms(
    lr=0.04, init="random",
)
for hpm, rlist in sorted(runs_rand_lr004.items(), key=lambda item: str(item[0])):
    print(f" - {str(hpm):70} | {len(rlist):,} RIDS")

 - model.weights=None optim.lr=0.04 optim.weight_decay=0.0001             | 20 RIDS
 - model.weights=None optim.lr=0.04 optim.weight_decay=0.00016            | 20 RIDS
 - model.weights=None optim.lr=0.04 optim.weight_decay=0.00025            | 20 RIDS
 - model.weights=None optim.lr=0.04 optim.weight_decay=4e-05              | 20 RIDS
 - model.weights=None optim.lr=0.04 optim.weight_decay=6.3e-05            | 20 RIDS


In [182]:
# Pretrained-init runs at lr=0.1, one line per weight-decay setting.
# Hpm keys do not support `<`, so sort by their string form with `key=`.
# Unlike sorting (str, list) tuples, this never compares the run lists on ties.
runs_pre_lr01 = rg.select_run_data_by_hpms(
    lr=0.1, init="pretrained",
)
for hpm, rlist in sorted(runs_pre_lr01.items(), key=lambda item: str(item[0])):
    print(f" - {str(hpm):70} | {len(rlist):,} RIDS")

 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.0001           | 103 RIDS
 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.00016          | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.00025          | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=4e-05            | 20 RIDS
 - model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=6.3e-05          | 20 RIDS


In [183]:
# Random-init runs at lr=0.1, one line per weight-decay setting. These are
# named `runs_rand_*` (not `runs_pre`) because they are random-init runs.
# Hpm keys do not support `<`, so sort by their string form with `key=`.
# Unlike sorting (str, list) tuples, this never compares the run lists on ties.
runs_rand_lr01 = rg.select_run_data_by_hpms(
    lr=0.1, init="random",
)
for hpm, rlist in sorted(runs_rand_lr01.items(), key=lambda item: str(item[0])):
    print(f" - {str(hpm):70} | {len(rlist):,} RIDS")

 - model.weights=None optim.lr=0.1 optim.weight_decay=0.0001              | 99 RIDS
 - model.weights=None optim.lr=0.1 optim.weight_decay=0.00016             | 20 RIDS
 - model.weights=None optim.lr=0.1 optim.weight_decay=0.00025             | 20 RIDS
 - model.weights=None optim.lr=0.1 optim.weight_decay=1e-05               | 20 RIDS
 - model.weights=None optim.lr=0.1 optim.weight_decay=4e-05               | 20 RIDS
 - model.weights=None optim.lr=0.1 optim.weight_decay=6.3e-05             | 20 RIDS


## Test Result Plotting

In [196]:
# Hpm specs for the "one config per init" comparison, built with no lr/wd/epochs
# arguments. NOTE(review): the comparison output reports the lr=0.1 / wd=1e-4
# configs, presumably the defaults of rplt.make_hpm_specs — confirm there.
hpm_specs_one_each = rplt.make_hpm_specs()

In [197]:
# Bootstrap (1000 resamples) comparison of pretrained vs random init, one hpm
# config each, on the validation split. The summary is printed in the next cell.
# NOTE(review): 260 and 80 are passed positionally. The summary reports
# "(80, 260)" with n=80, so 80 looks like the number of runs sampled and 260 an
# epoch bound. Confirm the parameter names in rplt and pass them by keyword.
compare_stats_one_each = rplt.one_tn_no_hpm_select_compare_weight_init(
    rg, hpm_specs_one_each, 260, 80, num_bootstraps=1000, split="val",
)

In [199]:
# Print the bootstrap summary statistics for the one-config-each weight-init
# comparison computed in the previous cell.
rplt.print_comparative_summary_stats(compare_stats_one_each)

>> :: Per Dist Summary Stats ::

Compare using 1000 bootstraps:
  - [(80, 260) | best: 259] model.weights=DEFAULT optim.lr=0.1 optim.weight_decay=0.0001
  - [(80, 260) | best: 259] model.weights=None optim.lr=0.1 optim.weight_decay=0.0001

point
   n          | 80.00000000 | 80.00000000
   mean       | 85.11064845 | 84.42620394
   median     | 85.22565323 | 84.40210895
   min        | 82.01337959 | 82.93967748
   max        | 87.70918906 | 85.43803764
   variance   | 2.08370823 | 0.21561246
   std        | 1.44062701 | 0.46147093
   sem        | 0.16106700 | 0.05159402
   2.5th      | 82.30567447 | 83.58419981
   25th       | 84.08577075 | 84.15236169
   75th       | 86.24519642 | 84.74314260
   97.5th     | 87.33109846 | 85.29649194
   IQR        | 2.15942568 | 0.59078091
std
   n          | 0.00000000 | 0.00000000
   mean       | 0.16173210 | 0.05063997
   median     | 0.25405055 | 0.05847733
   min        | 0.17324780 | 0.50819271
   max        | 0.22364984 | 0.06906447
   variance 

In [200]:
# Hpm grid for the hpm-selection comparison: a 3x3 subset of lr x wd at 270
# epochs. A wider grid was previously used here:
#   lr in [0.04, 0.06, 0.1, 0.16, 0.25]
#   wd in [1e-05, 4e-05, 6.3e-05, 0.0001, 0.00016, 0.00025]
# NOTE(review): the run listings above show wd=1e-05 only for random init
# (none for pretrained at lr=0.04 or lr=0.1). Confirm the comparison handles
# that asymmetric cell.
grid_lrs = [0.04, 0.1, 0.25]
grid_wds = [1e-05, 0.0001, 0.00025]
hpm_specs_hpm_select = rplt.make_hpm_specs(lr=grid_lrs, wd=grid_wds, epochs=270)
del grid_lrs, grid_wds
print(hpm_specs_hpm_select)

{'optim.lr': [0.04, 0.1, 0.25], 'optim.weight_decay': [1e-05, 0.0001, 0.00025], 'epochs': 270}


In [220]:
# Bootstrap (1000 resamples) comparison of pretrained vs random init, with
# hpm selection over the grid defined in the previous cell.
# FIXME: currently raises `TypeError: '<' not supported between instances of
# 'Hpm' and 'Hpm'`. Something in this call path presumably sorts or compares
# Hpm objects directly. The run-listing cells above avoid this by sorting on
# str(hpm); the fix belongs in rplt (e.g. sort with key=str) or in Hpm (define
# an ordering).
# NOTE(review): unlike the one-each comparison, split="val" is not passed here.
# Confirm that the default split is the intended one.
compare_stats_hpm_select = rplt.one_tn_hpm_compare_weight_init(
    rg, hpm_specs_hpm_select, 260, 80, num_bootstraps=1000,
)

TypeError: '<' not supported between instances of 'Hpm' and 'Hpm'

In [None]:
# Print the bootstrap summary for the hpm-selection comparison. The previous
# version re-printed compare_stats_one_each, which was already shown above.
# Requires the previous cell to have run successfully.
rplt.print_comparative_summary_stats(compare_stats_hpm_select)

In [192]:
# Scratch: two small string sets for a quick intersection check. Unrelated to
# the sweep analysis; candidate for removal.
sa = {'c', 'b', 'a'}
sb = {'d', 'a'}

In [193]:
# Scratch: elements common to both sets (displayed as the cell's output).
sa.intersection(sb)

{'a'}