# Experiment

In [None]:
#| default_exp ml.experiment

In [None]:
#| hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#| export

from pathlib import Path
import torch
import wandb
import json
from bellek.utils import context_chdir, NestedDict, flatten_dict

In [None]:
#| export

def make_experiment_dir(root="./experiments", name=None):
    """Create (if missing) and return the directory ``Path(root) / name``.

    Args:
        root: parent directory under which experiment folders live.
        name: folder name for this experiment. If ``None``, the name comes from
            ``bellek.utils.generate_time_id()`` (presumably a timestamp-like id).

    Returns:
        The ``pathlib.Path`` of the experiment directory. Missing parents are
        created, and an already-existing directory is reused without error.
    """
    if name is not None:
        dir_name = name
    else:
        # Imported lazily: only needed when the caller did not choose a name.
        from bellek.utils import generate_time_id
        dir_name = generate_time_id()

    target = Path(root).joinpath(dir_name)
    target.mkdir(exist_ok=True, parents=True)
    return target

In [None]:
#| export

def prepare_config(config):
    """Fill in runtime defaults and absolutize path entries of ``config`` in place.

    - If ``"device"`` is missing, it is set to ``torch.device("cuda")`` when CUDA
      is available, otherwise ``torch.device("cpu")``.
    - Every flattened entry whose key is a string ending in ``"path"`` is replaced
      by its absolute, resolved path as a ``str``. Resolution uses the current
      working directory at call time. ``main`` calls this before changing into
      the experiment directory, so relative paths refer to the launch directory.
      Entries whose value is ``None`` (an unset optional path) are left as-is.

    Args:
        config: NestedDict-like mapping supporting ``in``, item assignment,
            ``flat()`` and ``set(key, value)``.

    Returns:
        The same ``config`` object, for chaining.
    """
    if "device" not in config:
        config["device"] = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for key, value in config.flat().items():
        # An optional path left as null in the JSON config must not crash:
        # Path(None) raises TypeError.
        if value is None:
            continue
        if isinstance(key, str) and key.endswith("path"):
            config.set(key, str(Path(value).resolve()))
    return config

In [None]:
#| export

def make_run_experiment_sweep(run_experiment, config_defaults):
    """Wrap ``run_experiment`` into a zero-argument callable that opens a W&B run.

    Each call of the returned function does three things:
    - reads ``config_defaults["wandb"]`` as extra keyword arguments for ``wandb.init``;
    - passes the flattened ``config_defaults`` as the run config;
    - invokes ``run_experiment(run)`` inside the run's ``with`` block.

    Taking no arguments lets the callable be used both directly and as the
    ``function`` handed to ``wandb.agent``.
    """
    def _run_one_trial():
        init_kwargs = config_defaults["wandb"]
        flat_config = flatten_dict(config_defaults)
        with wandb.init(config=flat_config, **init_kwargs) as run:
            run_experiment(run)

    return _run_one_trial

def main(run_experiment, args):
    """Load JSON configs and run the experiment once, or as a W&B sweep.

    Args:
        run_experiment: callable invoked with the active ``wandb`` run.
        args: object with two attributes:
            ``cfg``: path to the JSON experiment config.
            ``sweep_cfg``: path to a JSON W&B sweep config, or a falsy value
            to run a single experiment.
            The sweep config may contain an extra ``"count"`` key. It is removed
            before the config is registered with ``wandb.sweep`` and is forwarded
            to ``wandb.agent``. When absent, ``count=None`` (wandb's default) is
            used.

    The experiment config is loaded and its path entries resolved before
    changing into a fresh directory created by ``make_experiment_dir()``.
    Relative paths in the config therefore refer to the launch directory.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(args.cfg, encoding="utf-8") as f:
        config = prepare_config(NestedDict(json.load(f)))

    sweep_config = None
    if args.sweep_cfg:
        with open(args.sweep_cfg, encoding="utf-8") as f:
            sweep_config = json.load(f)

    run_experiment_sweep = make_run_experiment_sweep(run_experiment, config)
    with context_chdir(make_experiment_dir()):
        wandb_params = config["wandb"]
        if sweep_config is None:
            run_experiment_sweep()
        else:
            # "count" is consumed here for wandb.agent rather than sent to
            # wandb.sweep; a missing key previously raised KeyError.
            count = sweep_config.pop("count", None)
            sweep_id = wandb.sweep(
                sweep_config,
                entity=wandb_params["entity"],
                project=wandb_params["project"],
            )
            wandb.agent(sweep_id, run_experiment_sweep, count=count)


In [None]:
#| hide
import nbdev; nbdev.nbdev_export()