In [94]:
%matplotlib inline

from hydra import initialize, compose
from omegaconf import OmegaConf 

import dr_gen.utils.run as ru
import dr_gen.utils.display as dsp

from dr_gen.analyze.run_group import RunGroup
import dr_gen.analyze.result_plotting as rplt

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Setup Config and Generator

In [95]:
# Compose the project config from ../configs/config.yaml, selecting the `mac` paths preset.
with initialize(config_path="../configs/", version_base=None):
    cfg = compose(config_name="config.yaml", overrides=["paths=mac"])

In [96]:
# Make the run deterministic from the configured seed (NOTE(review): presumably seeds all RNGs);
# the returned generator is kept for later use.
run_seed = cfg.seed
generator = ru.set_deterministic(run_seed)

In [97]:
# Show the path settings with interpolations resolved, as YAML.
resolved_paths = OmegaConf.to_container(cfg.paths, resolve=True)
print(OmegaConf.to_yaml(resolved_paths))

root: /Users/daniellerothermel/drotherm
proj_dir_name: dr_gen
data: /Users/daniellerothermel/drotherm/data
logs: /Users/daniellerothermel/drotherm/logs
my_data: /Users/daniellerothermel/drotherm/data/dr_gen
my_logs: /Users/daniellerothermel/drotherm/logs/dr_gen
run_dir: /Users/daniellerothermel/drotherm/logs/dr_gen/bs500/lr0.1/wd0.0001/s0/2025-03-05/18-57-1741222623
dataset_cache_root: /Users/daniellerothermel/drotherm/data/cifar10/
agg_results: /Users/daniellerothermel/drotherm/data/dr_gen/cifar10/cluster_runs/lr_wd_init_v0



### Load, Dissect, and Filter the Sweep

In [217]:
# Load every run found under the aggregated-results directory into a single RunGroup.
rg = RunGroup()
results_dir = cfg.paths.agg_results
rg.load_runs_from_base_dir(results_dir)

>> 0 / 1288 files failed parsing
>> Updated hpm sweep info


In [218]:
# Drop the runs whose `epochs` hyperparameter is 180 from the group.
ignore_hpms = {"epochs": 180}
rg.ignore_runs_by_hpms(**ignore_hpms)

>> Ignoring rid: 1287
>> Updated hpm sweep info


In [219]:
# Display each swept hyperparameter together with the values it takes.
swept_table_data = rg.get_swept_table_data()
print(dsp.make_table(*swept_table_data))

+------+------------+
| Key  |   Values   |
+------+------------+
| Init |   random   |
|      | pretrained |
+------+------------+
|  WD  |  6.3e-05   |
|      |  0.00025   |
|      |  0.00016   |
|      |   0.0001   |
|      |   4e-05    |
|      |   1e-05    |
+------+------------+
|  LR  |    0.06    |
|      |    0.16    |
|      |    0.2     |
|      |    0.25    |
|      |    0.04    |
|      |    0.01    |
|      |    0.1     |
+------+------------+


In [227]:
# Run counts per (Init, LR, WD) combo. The extra keyword arguments to print_table
# appear to act as row filters (output shows only LR 0.1 / 0.01) — confirm in dsp.
hpm_sweep_data = rg.get_hpms_sweep_table()
table = dsp.make_table(*hpm_sweep_data)
table_filters = {"lr": [0.1, 0.01], "epochs": 270}
dsp.print_table(table, drop_cols=[], sort_cols=['Init', 'LR', 'WD'], **table_filters)

+------------+------+---------+-------+
|    Init    |  LR  |    WD   | Count |
+------------+------+---------+-------+
| pretrained | 0.01 |  0.0001 |   20  |
| pretrained | 0.01 |  1e-05  |   5   |
| pretrained | 0.1  |  0.0001 |  103  |
| pretrained | 0.1  | 0.00016 |   20  |
| pretrained | 0.1  | 0.00025 |   20  |
| pretrained | 0.1  |  4e-05  |   20  |
| pretrained | 0.1  | 6.3e-05 |   20  |
|   random   | 0.01 |  0.0001 |   20  |
|   random   | 0.01 |  1e-05  |   20  |
|   random   | 0.1  |  0.0001 |   99  |
|   random   | 0.1  | 0.00016 |   20  |
|   random   | 0.1  | 0.00025 |   20  |
|   random   | 0.1  |  1e-05  |   20  |
|   random   | 0.1  |  4e-05  |   20  |
|   random   | 0.1  | 6.3e-05 |   20  |
+------------+------+---------+-------+


### Update `rplt` (result plotting utils)

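### Util Usage

> **NOTE:** the example cells below use `rp`, `pu`, `hp`, `ks`, and `run_logs`, none of which are imported or defined in the cells above. This section fails on a fresh kernel until those names are brought into scope.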

#### Loading and Parsing Run Logs

In [None]:
# Example: parse all run logs, separating cleanly parsed runs from runs with errors.
# NOTE(review): `rp` and `run_logs` are not defined in any visible cell, so this fails on a fresh kernel.
print(f">> Parsing {len(run_logs):,} Runs")
good_runs, bad_runs = rp.parse_run_logs(run_logs)
for count_label, run_list in (("num good", good_runs), ("num w/ errors", bad_runs)):
    print(f"   - {count_label}: {len(run_list):,}")

#### Determining the Sweeps in the Data

In [None]:
# Example: find which hyperparameters were swept and show their values.
sweep_info = rp.extract_sweeps(good_runs)
swept_table_data = rp.get_swept_table(sweep_info['swept_vals'])
swept_table = dsp.make_table(*swept_table_data)
print(f"Swept Values:\n{swept_table}")

In [None]:
# Example: tabulate every sweep combo and preview the first rows.
combo_table_data = rp.get_combo_table_contents(
    sweep_info['combo_key_order'],
    sweep_info['combo_inds'],
)
combo_table = dsp.make_table(*combo_table_data)

print(">> First 5 lines of combos")
print(combo_table.get_string(start=0, end=5))

print("\n>> Just Default HPMs")
def filter_function(vals, lr="0.1", weight_decay="0.0001", epochs="270", init=None):
    """Row filter for ``combo_table.get_string``: keep rows matching the given HPMs.

    Columns are addressed by position: 0=Init, 1=LR, 2=WD, 4=Epochs.
    NOTE(review): this positional layout must match ``sweep_info['combo_key_order']``
    — confirm if the sweep keys change.

    Args:
        vals: One table row; cell values are compared as strings.
        lr: Required LR value (column 1).
        weight_decay: Required WD value (column 2).
        epochs: Required epoch count (column 4).
        init: Required Init value (column 0); ``None`` accepts any init.

    Returns:
        True when the row matches every requested value, else False.
    """
    return (
        (init is None or vals[0] == init)
        and vals[1] == lr
        and vals[2] == weight_decay
        and vals[4] == epochs
    )
# Print only the combo rows accepted by filter_function.
default_combo_rows = combo_table.get_string(row_filter=filter_function)
print(default_combo_rows)

In [None]:
# Example: select the runs matching a partial configuration.
# selected_inds maps each matching combo-key tuple to its list of run indices.
selected_inds = rp.get_inds_by_kvs(
    sweep_info['combo_key_order'],
    sweep_info['combo_inds'],
    {'epochs': '270', 'model.weights': 'None', 'optim.lr': '0.1'},
)
dsp.make_table(
    list(sweep_info['combo_key_order']) + ["Count"],
    [list(combo) + [len(run_inds)] for combo, run_inds in selected_inds.items()],
)

#### Format and Extract Metrics

In [None]:
# Re-index the per-run metrics by split, then by metric name (as the key listings below show).
all_runs_metrics = rp.remap_run_list_metrics(good_runs)
split_names = list(all_runs_metrics.keys())
train_metric_names = list(all_runs_metrics['train'].keys())
print(f">> Splits: {split_names}")
print(f">> Metrics: {train_metric_names}")

In [None]:
# Example: the first five train-loss values of run 0.
run0_train_loss = rp.get_run_metrics(all_runs_metrics, "train", "loss", 0)
run0_train_loss[:5]

In [None]:
# Example: fetch one metric for several runs at once and show the first four values side by side.
two_mets = rp.get_runs_metrics(all_runs_metrics, "train", "loss", [0, 1])
print(">> Loss for two runs")
run0_loss, run1_loss = two_mets[0], two_mets[1]
first_rows = [[run0_loss[row], run1_loss[row]] for row in range(4)]
dsp.make_table(['Run 0', 'Run 1'], first_rows)

In [None]:
# Example: select runs at LR 0.1 / WD 0.0001 (any init) and summarize their train loss.
# selected_inds maps each matching combo-key tuple to its list of run indices.
default_lr_wd_kvs = {
    'epochs': '270',
    'optim.lr': '0.1',
    'optim.weight_decay': '0.0001',
}
selected_inds = rp.get_inds_by_kvs(
    sweep_info['combo_key_order'], sweep_info['combo_inds'], default_lr_wd_kvs,
)
print(dsp.make_table(
    [*sweep_info['combo_key_order'], "Count"],
    [[*combo, len(run_inds)] for combo, run_inds in selected_inds.items()],
))
selected_metrics = rp.get_selected_run_metrics(
    all_runs_metrics, "train", "loss", selected_inds,
)
for combo, run_vals in selected_metrics.items():
    print(f">> Selected values: {combo}")
    print(f"  - Num Runs: {len(run_vals)}")
    first_run = run_vals[0]
    print(f"  - Epochs Per Run: {len(first_run)}")
    print(f"  - First few metric vals: {first_run[:6]}")
    print()

#### Basic Plotting Utils

In [None]:
# Example: overlay the train and val `acc1` curves of run 0.
acc_splits = ("train", "val")
run0_acc_curves = [
    rp.get_run_metrics(all_runs_metrics, split, "acc1", 0) for split in acc_splits
]
pu.plot_lines(
    pu.get_plt_cfg(
        ylabel="acc",
        title="Accuracy During Training",
        ylim=(70, 100),
        labels=list(acc_splits),
    ),
    run0_acc_curves,
)

In [None]:
# Example: val `acc1` curves for a hand-picked list of run indices.
manual_run_inds = [0, 1, 2]
val_acc_curves = rp.get_runs_metrics(all_runs_metrics, "val", "acc1", manual_run_inds)
pu.plot_lines(
    pu.get_plt_cfg(
        ylabel="acc",
        title="Val Accuracy During Training",
        ylim=(70, 86),
        labels=list(manual_run_inds),
    ),
    val_acc_curves,
)

In [None]:
# Example: val `acc1` curves for every run of the first combo matching the kvs below.
selected_inds = rp.get_inds_by_kvs(
    sweep_info['combo_key_order'], sweep_info['combo_inds'],
    {
        'epochs': '270',
        'model.weights': 'None',
        'optim.lr': '0.1',
        'optim.weight_decay': '0.0001',
    },
)
selected_metrics = list(rp.get_selected_run_metrics(
    all_runs_metrics, "val", "acc1", selected_inds,
).items())
met_keys = [pair[0] for pair in selected_metrics]
met_values = [pair[1] for pair in selected_metrics]
# Title with the combo's key=value pairs; legend disabled since the curves are individual runs.
first_combo_kvs = list(zip(sweep_info['combo_key_order'], met_keys[0]))
pu.plot_lines(
    pu.get_plt_cfg(
        ylabel="acc",
        title=rplt.kvs_to_str(first_combo_kvs),
        ylim=(70, 86),
        legend=False,
    ),
    met_values[0],
)

#### Simple Histogram and Ind Data

In [None]:
# Example: aggregate stats at index 100 of the val values, using 5 seeds of one combo.
default_combo_kvs = {
    'epochs': '270',
    'model.weights': 'None',
    'optim.lr': '0.1',
    'optim.weight_decay': '0.0001',
}
all_kvs, all_split_vals, _ = rp.get_selected_combo(
    good_runs, all_runs_metrics, sweep_info,
    kv_select=default_combo_kvs,
    splits=['train', 'val'],
    ignore_keys=["optim.step_size", "epochs"],
    num_seeds=5,
)
first_group_val = all_split_vals['val'][0]
rd_stats_ind100 = pu.get_runs_data_stats_ind(first_group_val, ind=100)
rd_stats_ind100

In [None]:
# Example: histogram of the val values at index 100 across the seeds of the default combo.
all_kvs, all_split_vals, _ = rp.get_selected_combo(
    good_runs,
    all_runs_metrics,
    sweep_info,
    kv_select={
        'epochs': '270',
        'model.weights': 'None',
        'optim.lr': '0.1',
        'optim.weight_decay': '0.0001',
    },
    splits=['train', 'val'],
    ignore_keys=["optim.step_size", "epochs"],
)
rd_stats_ind100 = pu.get_runs_data_stats_ind(
    all_split_vals['val'][0],
    ind=100,
)
n_runs = rd_stats_ind100['n']
hp.plot_histogram(
    pu.get_plt_cfg(
        # About 4 runs per bin, but never 0 bins: n < 4 would otherwise give nbins=0,
        # which histogram binning rejects.
        nbins=max(1, n_runs // 4),
        hist_range=(80, 90),
        title=f"Accuracy Distribution, {n_runs} Seeds",
        ylabel="Num Runs",
    ),
    rd_stats_ind100['vals'],
)

In [None]:
# Example: compare val-value histograms at index 100 between the first two groups that
# share LR/WD/epochs. `model.weights` is left out of kv_select so each init is its own group.
vary_key = "model.weights"
all_kvs, all_split_vals, _ = rp.get_selected_combo(
    good_runs,
    all_runs_metrics,
    sweep_info,
    kv_select={
        'epochs': '270',
        'optim.lr': '0.1',
        'optim.weight_decay': '0.0001',
    },
    splits=['train', 'val'],
    ignore_keys=["optim.step_size", "epochs"],
)
stats1_ind100 = pu.get_runs_data_stats_ind(
    all_split_vals['val'][0],
    ind=100,
)
stats2_ind100 = pu.get_runs_data_stats_ind(
    all_split_vals['val'][1],
    ind=100,
)
# Legend label per group: only the varying key.
labels_kvstrs = [
    rplt.kvs_to_str([(k, v) for k, v in kvs if k == vary_key])
    for kvs in all_kvs
]
# Title: every non-varying key, taken from the first group.
title_kvstr = rplt.kvs_to_str(
    [(k, v) for k, v in all_kvs[0] if k != vary_key]
)
ns = [stats1_ind100['n'], stats2_ind100['n']]
hp.plot_histogram_compare(
    pu.get_plt_cfg(
        # About 4 runs per bin for the larger group, never 0 bins.
        nbins=max(1, max(ns) // 4),
        hist_range=(80, 90),
        title=f"Accuracy Distribution | {title_kvstr}",
        ylabel="Num Runs",
        # zip pairs each plotted group with its own count (only two groups are plotted).
        labels=[f"{label} n={n} " for label, n in zip(labels_kvstrs, ns)],
    ),
    [
        stats1_ind100,
        stats2_ind100,
    ],
)

In [None]:
# Example: compare the two init groups' val values at index 100 via
# ks.calculate_ks_for_run_sets (NOTE(review): presumably a two-sample KS test) and plot CDFs.
vary_key = "model.weights"
all_kvs, all_split_vals, _ = rp.get_selected_combo(
    good_runs,
    all_runs_metrics,
    sweep_info,
    kv_select={
        'epochs': '270',
        'optim.lr': '0.1',
        'optim.weight_decay': '0.0001',
    },
    splits=['train', 'val'],
    ignore_keys=["optim.step_size", "epochs"],
)
stats1_ind100 = pu.get_runs_data_stats_ind(
    all_split_vals['val'][0],
    ind=100,
)
stats2_ind100 = pu.get_runs_data_stats_ind(
    all_split_vals['val'][1],
    ind=100,
)
results = ks.calculate_ks_for_run_sets(
    stats1_ind100['vals'],
    stats2_ind100['vals'],
)
# Single-CDF plot for the first group only.
pu.plot_cdf(
    pu.get_plt_cfg(),
    results['all_vals'],
    results['cdf1'],
)
# Computed here rather than reused from the previous cell, so this cell runs on a fresh
# kernel and the counts always match the groups selected above.
ns = [stats1_ind100['n'], stats2_ind100['n']]
pu.plot_cdfs(
    pu.get_plt_cfg(
        labels=[
            rplt.kvs_to_str([(k, v) for k, v in kvs if k == vary_key]) + f" n={n} "
            for kvs, n in zip(all_kvs, ns)
        ],
    ),
    results['all_vals'],
    [
        results['cdf1'],
        results['cdf2'],
    ],
)

#### Using Result Plotting Utils

In [None]:
# Plot the split curves of run 0 with the result-plotting helper.
run_splits_kwargs = dict(
    run_ind=0,
    ignore_keys=['optim.step_size', 'epochs'],
    ylim=(75, 100),
)
rplt.plot_run_splits(good_runs, all_runs_metrics, sweep_info, **run_splits_kwargs)

In [None]:
# Summarize train/val curves of the default combo with num_seeds=None
# (NOTE(review): presumably every available seed).
summary_kv_select = {
    'epochs': '270',
    'model.weights': 'None',
    'optim.lr': '0.1',
    'optim.weight_decay': '0.0001',
}
rplt.plot_split_summaries(
    good_runs, all_runs_metrics, sweep_info,
    kv_select=summary_kv_select,
    splits=['train', 'val'],
    ignore_keys=["optim.step_size", "epochs"],
    num_seeds=None,
    ylim=(82, 86),
)
    

In [None]:
# Same summary as above, but with num_seeds=20.
summary20_kv_select = {
    'epochs': '270',
    'model.weights': 'None',
    'optim.lr': '0.1',
    'optim.weight_decay': '0.0001',
}
rplt.plot_split_summaries(
    good_runs, all_runs_metrics, sweep_info,
    kv_select=summary20_kv_select,
    splits=['train', 'val'],
    ignore_keys=["optim.step_size", "epochs"],
    num_seeds=20,
    ylim=(82, 86),
)
    

In [None]:
# Density histogram of val acc1 at epoch 110 for the default combo.
# nbins is not passed, so the helper's default bin count applies.
hist_kv_select = {
    'epochs': '270',
    'model.weights': 'None',
    'optim.lr': '0.1',
    'optim.weight_decay': '0.0001',
}
rplt.plot_combo_histogram(
    good_runs, all_runs_metrics, sweep_info,
    kv_select=hist_kv_select,
    split='val',
    epoch=110,
    metric='acc1',
    ignore_keys=["optim.step_size", "epochs"],
    num_seeds=None,
    hist_range=(82, 86),
    density=True,
)

In [None]:
# Compare val acc1 histograms at epoch 100 across model.weights values (num_seeds=20).
# model.weights is omitted from kv_select because it is the varied key.
hist_cmp_kv_select = {
    'epochs': '270',
    'optim.lr': '0.1',
    'optim.weight_decay': '0.0001',
}
rplt.plot_combo_histogram_compare(
    good_runs, all_runs_metrics, sweep_info,
    kv_select=hist_cmp_kv_select,
    split="val",
    epoch=100,
    metric="acc1",
    ignore_keys=["optim.step_size", "epochs"],
    num_seeds=20,
    nbins=10,
    vary_key="model.weights",
)

In [None]:
# KS stats and CDFs of val acc1 at epoch 100, varying model.weights.
# model.weights is omitted from kv_select because it is the varied key.
ks_weights_kv_select = {
    'epochs': '270',
    'optim.lr': '0.1',
    'optim.weight_decay': '0.0001',
}
rplt.ks_stats_plot_cdfs(
    good_runs, all_runs_metrics, sweep_info,
    kv_select=ks_weights_kv_select,
    split="val",
    epoch=100,
    metric="acc1",
    ignore_keys=["optim.step_size", "epochs"],
    num_seeds=None,
    vary_key="model.weights",
)

In [None]:
# KS stats and CDFs of val acc1 at epoch 100, comparing LR 0.1 against LR 0.01.
# optim.lr is omitted from kv_select because it is the varied key.
ks_lr_kv_select = {
    'epochs': '270',
    'model.weights': 'None',
    'optim.weight_decay': '0.0001',
}
rplt.ks_stats_plot_cdfs(
    good_runs, all_runs_metrics, sweep_info,
    kv_select=ks_lr_kv_select,
    split="val",
    epoch=100,
    metric="acc1",
    ignore_keys=["optim.step_size", "epochs"],
    num_seeds=None,
    vary_key="optim.lr",
    vary_vals=["0.1", "0.01"],  # identical to [str(0.1), str(0.01)]
)

In [None]:
# KS stats with CDFs and histograms (40 bins) of val acc1 at epoch 100, varying model.weights.
# model.weights is omitted from kv_select because it is the varied key.
ks_hist_weights_kv_select = {
    'epochs': '270',
    'optim.lr': '0.1',
    'optim.weight_decay': '0.0001',
}
rplt.ks_stat_plot_cdfs_histograms(
    good_runs, all_runs_metrics, sweep_info,
    kv_select=ks_hist_weights_kv_select,
    split="val",
    epoch=100,
    metric="acc1",
    ignore_keys=["optim.step_size", "epochs"],
    num_seeds=None,
    vary_key="model.weights",
    nbins=40,
)

In [None]:
# KS stats with CDFs and histograms of val acc1 at epoch 100: LR 0.1 vs 0.01, num_seeds=20.
# optim.lr is omitted from kv_select because it is the varied key.
ks_hist_lr_kv_select = {
    'epochs': '270',
    'model.weights': 'None',
    'optim.weight_decay': '0.0001',
}
rplt.ks_stat_plot_cdfs_histograms(
    good_runs, all_runs_metrics, sweep_info,
    kv_select=ks_hist_lr_kv_select,
    split="val",
    epoch=100,
    metric="acc1",
    ignore_keys=["optim.step_size", "epochs"],
    num_seeds=20,
    vary_key="optim.lr",
    vary_vals=["0.1", "0.01"],  # identical to [str(0.1), str(0.01)]
    nbins=40,
)

In [None]:
# Same as the previous cell, but comparing LR 0.1 against LR 0.04.
# optim.lr is omitted from kv_select because it is the varied key.
ks_hist_lr004_kv_select = {
    'epochs': '270',
    'model.weights': 'None',
    'optim.weight_decay': '0.0001',
}
rplt.ks_stat_plot_cdfs_histograms(
    good_runs, all_runs_metrics, sweep_info,
    kv_select=ks_hist_lr004_kv_select,
    split="val",
    epoch=100,
    metric="acc1",
    ignore_keys=["optim.step_size", "epochs"],
    num_seeds=20,
    vary_key="optim.lr",
    vary_vals=["0.1", "0.04"],  # identical to [str(0.1), str(0.04)]
    nbins=40,
)