In [1]:
import wandb
import os
import pickle


def get_history(user="kealexanderwang", project="constrained-pnns", query=None,
                **kwargs):
    """
    Fetch the logged history of every W&B run matching a query.

    Args:
        user: W&B entity (user or team) that owns the project.
        project: W&B project name.
        query: Filter dict passed as ``filters`` to ``wandb.Api().runs``.
            ``None`` (the default) is treated as ``{}``, i.e. no filtering.
        **kwargs: Forwarded unchanged to ``run.history`` for each run.

    Returns:
        A list of ``(run, history)`` pairs, one per matching run, in the
        order the API yielded the runs.
    """
    # Use None rather than a shared `{}` default so no call can ever leak
    # state into a later call through the default object.
    if query is None:
        query = {}
    api = wandb.Api()
    # Materialize the paginated result once so runs and histories come from
    # a single pass and are guaranteed to line up pairwise.
    runs = list(api.runs(path=f"{user}/{project}", filters=query))
    return [(run, run.history(**kwargs)) for run in runs]


def download_files(user="kealexanderwang", project="constrained-pnns",
                   query=None, save_dir=".", **kwargs):
    """
    Download the files of each matching run into a new directory for the run.
    Also saves the config dict of the run.

    For every run matching ``query``, a directory ``<save_dir>/<run.name>`` is
    created. The run's ``config`` is pickled to ``config.pkl`` inside it, and
    all of the run's files are downloaded there. Existing copies are
    overwritten, so the function can be re-run safely.

    Args:
        user: W&B entity (user or team) that owns the project.
        project: W&B project name.
        query: Filter dict passed as ``filters`` to ``wandb.Api().runs``.
            ``None`` (the default) is treated as ``{}``, i.e. no filtering.
        save_dir: Parent directory for the per-run directories. It is created,
            including any missing parents, if it does not exist.
        **kwargs: Accepted but currently unused.
    """
    if query is None:
        query = {}
    # makedirs (unlike mkdir) also creates missing parent directories.
    os.makedirs(save_dir, exist_ok=True)

    api = wandb.Api()
    for run in api.runs(path=f"{user}/{project}", filters=query):
        run_dir = os.path.join(save_dir, run.name)
        os.makedirs(run_dir, exist_ok=True)

        with open(os.path.join(run_dir, "config.pkl"), "wb") as h:
            pickle.dump(run.config, h)

        for file in run.files():
            # With wandb's default replace=False, downloading over a file left
            # by a previous call raises. Overwrite instead, matching how
            # config.pkl above (and load_model_from_run) behave.
            file.download(root=run_dir, replace=True)

In [4]:
from pytorch_lightning import Trainer
from pl_trainer import DynamicsModel, SaveTestLogCallback
import os
import pprint 

def _checkpoint_sort_key(file):
    """Natural-sort key for a W&B file: digit runs in its name compare as ints."""
    import re  # local import keeps this helper self-contained

    # re.split with a capturing group alternates text / digits, so positions
    # line up across names and int-vs-str comparisons never occur.
    return [int(tok) if tok.isdigit() else tok
            for tok in re.split(r"(\d+)", file.name)]


def load_model_from_run(run, save_dir="/tmp"):
    """
    Download the latest checkpoint of a W&B run and load it as a DynamicsModel.

    Every file of ``run`` whose name contains ``"checkpoints"`` is considered.
    The last one in natural name order (so ``epoch=100`` sorts after
    ``epoch=99``) is downloaded into ``<save_dir>/<run.display_name>`` and
    loaded. The model's hyperparameters are pretty-printed as a side effect.

    Args:
        run: A W&B API run object.
        save_dir: Parent directory for the downloaded checkpoint.

    Returns:
        ``(pl_trainer, pl_model)``. ``pl_trainer`` is always ``None``; see the
        commented-out line below if a Trainer is needed.

    Raises:
        RuntimeError: If the run has no checkpoint files.
    """
    name = run.display_name
    ckpt_save_path = os.path.join(save_dir, name)
    os.makedirs(ckpt_save_path, exist_ok=True)

    # W&B File objects define no ordering, so sort by an explicit key on the
    # file name. Natural order makes numbered checkpoints sort numerically.
    ckpts = sorted((f for f in run.files() if "checkpoints" in f.name),
                   key=_checkpoint_sort_key)
    if not ckpts:
        raise RuntimeError(f"Run {name} has no checkpoints!")
    # pick latest checkpoint if available
    last_ckpt = ckpts[-1]
    last_ckpt.download(replace=True, root=ckpt_save_path)

    ckpt_path = os.path.join(ckpt_save_path, last_ckpt.name)
    # Uncomment if you need the trainer
    # trainer = Trainer(resume_from_checkpoint=ckpt_path,logger=False)
    pl_trainer = None
    pl_model = DynamicsModel.load_from_checkpoint(ckpt_path)

    # Depending on the Lightning version, hparams may be an argparse-style
    # Namespace or a dict subclass. vars() on a dict subclass would hide
    # its entries, so handle both.
    hparams = pl_model.hparams
    hparams = dict(hparams) if isinstance(hparams, dict) else vars(hparams)
    print("Model Hyperparameters:")
    pprint.pprint(hparams, indent=4)
    return pl_trainer, pl_model

In [5]:
# See https://docs.wandb.com/library/reference/wandb_api for how to write queries
query = {"tags": {"$eq": "3pendulum"}}
# Position (in API order) of the matching run to load a model from.
# NOTE(review): wandb's Api.runs appears to default to newest-first ordering,
# so this index can point at a different run once new runs are logged.
# Consider filtering on the run name/id instead.
RUN_INDEX = 1

run_histories = get_history(query=query)
if len(run_histories) <= RUN_INDEX:
    raise RuntimeError(
        f"Query {query} matched {len(run_histories)} run(s); "
        f"cannot load the run at index {RUN_INDEX}.")
runs, histories = zip(*run_histories)
pl_trainer, pl_model = load_model_from_run(runs[RUN_INDEX])

tensor(5.1746e-06)
NN ignores wgrad
NN currently assumes time independent ODE
Model Hyperparameters:
{   'angular_dims': range(0, 3),
    'batch_size': 800,
    'body_args': [3],
    'body_class': 'ChainPendulum',
    'callbacks': [   <pytorch_lightning.callbacks.lr_logger.LearningRateLogger object at 0x7f0ffc0b1f10>,
                     <pl_trainer.SaveTestLogCallback object at 0x7f0ffc0429d0>,
                     <pytorch_lightning.callbacks.progress.ProgressBar object at 0x7f0ffc042a50>],
    'check_val_every_n_epoch': 100,
    'chunk_len': 5,
    'ckpt_dir': '/home/alex_w/repos/hamiltonian-biases/experiments/ChainPendulumn3m1l1/NN/wandb/run-20200505_204111-2l6n8uf7/constrained-pnns/version_2l6n8uf7/checkpoints',
    'dataset_class': 'RigidBodyDataset',
    'debug': False,
    'dof_ndim': 3,
    'dt': 0.1,
    'euclidean': False,
    'exp_dir': '/home/alex_w/repos/hamiltonian-biases/experiments/ChainPendulumn3m1l1/NN',
    'fast_dev_run': False,
    'gpus': 1,
    'hidden_size': 2

