In [1]:
import sys
src_path = "/Users/kimathikaai/workspace/fond/"
if src_path not in sys.path:
    sys.path.insert(0, src_path)
print(sys.path)

['/Users/kimathikaai/workspace/fond/', '/Users/kimathikaai/workspace/fond/src/utils', '/opt/homebrew/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python311.zip', '/opt/homebrew/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11', '/opt/homebrew/Cellar/python@3.11/3.11.4_1/Frameworks/Python.framework/Versions/3.11/lib/python3.11/lib-dynload', '', '/Users/kimathikaai/workspace/envs/fond/lib/python3.11/site-packages']


In [2]:
import argparse

import wandb
import yaml
import copy
from tqdm import tqdm
import numpy as np

from src.train.fit import fit
from src.utils.hparams import random_hparams, seed_hash
from src.utils.run_info import get_project_runs, find_best_steps

%matplotlib inline
%load_ext autoreload
%autoreload 2

ENTITY = 'critical-ml-dg'
PROJECT_NAME = 'fond-checkpoints'
SWEEP_IDS = ['g8zq0tf1', "i4quzpc4", "5ewn6jqq"]
FILTERING_CRITERIA = {'kd_algo':'FOND'}
METRIC_NAME = 'val/nacc'
METRIC_GOAL = 'max'

In [3]:
# setup api
api = wandb.Api(
    overrides={
        "entity": ENTITY,
    }
)

# get runs
runs, unique_datasets = get_project_runs(
    api_conn=api,
    project_name=PROJECT_NAME,
    sweep_ids=SWEEP_IDS,
    filtering_criteria=FILTERING_CRITERIA,
)

[info] Extracting sweep information from the following sweeps: ['g8zq0tf1', 'i4quzpc4', '5ewn6jqq']


100%|██████████████████████████████████████████████████████| 120/120 [00:01<00:00, 72.39it/s]
100%|█████████████████████████████████████████████████████| 120/120 [00:01<00:00, 103.96it/s]
100%|██████████████████████████████████████████████████████| 120/120 [00:01<00:00, 95.82it/s]

[info] Runs to process: 240





In [4]:
# visualize a run
runs[0]

In [5]:
# visualize a run metric
list(runs[0].scan_history(keys=['val/nacc', 'step']))[-1]

{'val/nacc': 0.09854567050933838, 'step': 5000}

In [6]:
best_step_info = find_best_steps(
    runs=runs, metric_name=METRIC_NAME, metric_goal=METRIC_GOAL
)

100%|██████████████████████████████████████████████████████| 240/240 [10:13<00:00,  2.55s/it]


In [7]:
# Runs are organized based on dataset, overlap and test_id
INFO = {
    x:{
        y:{
            z: {'id': 0, 'step': 0, 'value': float('inf') if METRIC_GOAL=='min' else float('-inf')} for z in range(4)
        } for y in ['low', 'high', 'high_linked_only', 'low_linked_only']
    } 
    for x in ['PACS', 'VLCS', 'OfficeHome']
}
algorithm_info = {}
for run in tqdm(best_step_info):
    run_step_data = best_step_info[run]
    run_id = run.id
    run_dataset = run.config["dataset"]
    run_test_id = run.config["test_set_id"]
    run_overlap = run.config['overlap']
    run_algorithm = run.config['kd_algo']
    run_metric_value = run_step_data[METRIC_NAME]
    run_step = run_step_data['step_number']

    # update
    if run_algorithm not in algorithm_info:
        algorithm_info[run_algorithm] = copy.deepcopy(INFO)
        algorithm_info[run_algorithm][run_dataset][run_overlap][run_test_id] = {
            "id": run_id, 'value': run_metric_value, 'step':run_step}
        continue

    previous_val = algorithm_info[run_algorithm][run_dataset][run_overlap][run_test_id]['value']

    # compare based on goal
    if (METRIC_GOAL=="min") and (previous_val>run_metric_value):
        algorithm_info[run_algorithm][run_dataset][run_overlap][run_test_id] = {
            'id': run_id, 'value':run_metric_value, 'step': run_step}
    elif (METRIC_GOAL=="max") and (previous_val<run_metric_value):
        algorithm_info[run_algorithm][run_dataset][run_overlap][run_test_id] = {
            'id': run_id, 'value':run_metric_value, 'step': run_step}

algorithm_info    

100%|███████████████████████████████████████████████████| 240/240 [00:00<00:00, 91081.52it/s]


{'FOND': {'PACS': {'low': {0: {'id': 'xmqjyuzz',
     'value': 0.32841574649016064,
     'step': 2700},
    1: {'id': '7mx5skyi', 'value': 0.330560381213824, 'step': 4500},
    2: {'id': 'gybvfjta', 'value': 0.33100227018197376, 'step': 4200},
    3: {'id': 'oi7oxz3f', 'value': 0.3302321384350459, 'step': 3000}},
   'high': {0: {'id': 'lb4gwn9y', 'value': 0.3333333333333333, 'step': 900},
    1: {'id': '710c64zk', 'value': 0.3333333333333333, 'step': 600},
    2: {'id': 'simzmtfc', 'value': 0.3333333333333333, 'step': 900},
    3: {'id': '9hf5f6ag', 'value': 0.3311111132303874, 'step': 3000}},
   'high_linked_only': {0: {'id': 0, 'step': 0, 'value': -inf},
    1: {'id': 0, 'step': 0, 'value': -inf},
    2: {'id': 0, 'step': 0, 'value': -inf},
    3: {'id': 0, 'step': 0, 'value': -inf}},
   'low_linked_only': {0: {'id': 0, 'step': 0, 'value': -inf},
    1: {'id': 0, 'step': 0, 'value': -inf},
    2: {'id': 0, 'step': 0, 'value': -inf},
    3: {'id': 0, 'step': 0, 'value': -inf}}},
  'VL

In [8]:
# Find the corresponding test performance for a given step
# get the average METRIC_VALUE for each algorith, dataset and overlap
for algorithm in algorithm_info:
    for dataset in algorithm_info[algorithm]:
        for overlap in algorithm_info[algorithm][dataset]:
            for domain_id in algorithm_info[algorithm][dataset][overlap]:
                # Get the test values
            average = np.nanmean(list(algorithm_info[algorithm][dataset][overlap].values()))
            print(algorithm, dataset, overlap, f"{METRIC_NAME}: ", average)

TypeError: unsupported operand type(s) for +: 'dict' and 'dict'

In [None]:
# Report the test metric performance