Skip to content

Commit

Permalink
feat: add molboil, finalize experiment configs (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
lollcat committed Aug 1, 2023
1 parent 7d8aa46 commit 6e324cd
Show file tree
Hide file tree
Showing 104 changed files with 5,083 additions and 1,300 deletions.
6 changes: 3 additions & 3 deletions .gitignore
@@ -1,8 +1,8 @@
*__pycache__*
*/.ipynb_checkpoints
/target/data/qm9_valid.npy
/target/data/qm9_train.npy
/target/data/qm9_test.npy
*/target/data/qm9_valid.npy
*/target/data/qm9_train.npy
*/target/data/qm9_test.npy
*/temp/*
*.pyc
*.npz
Expand Down
27 changes: 16 additions & 11 deletions README.md
Expand Up @@ -2,24 +2,29 @@


# Install
At the time of publishing, this project was using jax 0.4.13 with python 3.10.
```
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
```
This project depends on PyTorch (NB: use the CPU version).
```
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```
For the alanine dipeptide problem we need to install openmmtools with conda:
`conda install -c conda-forge openmm openmmtools`
```
conda install -c conda-forge openmm openmmtools
```
Finally, install this package in editable mode:
```
pip install -e .
```

## Experiments
```shell
python examples/dw4.py
python examples/lj13.py
python examples/aldp.py
python examples/qm9_download_data.py
python examples/qm9.py
python examples/dw4_fab.py
python examples/lj13_fab.py
```

## Installation

1. Install packages in requirements.txt
2. Run

```
pip install -e .
```
24 changes: 16 additions & 8 deletions examples/analyse_results/dw4_results/plot.py
@@ -1,6 +1,6 @@
import jax.random
from omegaconf import DictConfig
import yaml
from hydra import compose, initialize
import hydra

from examples.load_flow_and_checkpoint import load_flow
from examples.default_plotter import *
Expand All @@ -25,7 +25,7 @@ def make_get_data_for_plotting(
max_n_samples: int = 10000,
plotting_n_nodes: Optional[int] = None,
max_distance: Optional[float] = 20.,
): # Override default plotter
): # Override default.yaml plotter
bins_x, count_list = bin_samples_by_dist([train_data.positions[:max_n_samples],
test_data.positions[:max_n_samples]], max_distance=max_distance)
n_samples = n_samples_from_flow
Expand Down Expand Up @@ -53,12 +53,19 @@ def get_data_for_plotting(state: TrainingState, key: chex.PRNGKey, train_data=tr
return get_data_for_plotting, count_list, bins_x


_BASE_DIR = '../../..'


def plot_dw4(ax: Optional = None):
download_checkpoint(flow_type='spherical', tags=["dw4", "final_run"], seed=0, max_iter=200,
base_path='./examples/dw4_results/models')
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=f"{_BASE_DIR}/examples/config/")
cfg = compose(config_name="dw4.yaml")

download_checkpoint(flow_type='spherical', tags=["dw4", "ml", "florence"], seed=0, max_iter=100,
base_path='./examples/analyse_results/dw4_results/models')

checkpoint_path = "examples/analyse_results/dw4_results/models/spherical_seed0.pkl"

checkpoint_path = "examples/dw4_results/models/spherical_seed0.pkl"
cfg = DictConfig(yaml.safe_load(open(f"examples/config/dw4.yaml")))
n_samples_from_flow_plotting = 1000
key = jax.random.PRNGKey(0)

Expand Down Expand Up @@ -94,4 +101,5 @@ def plot_dw4(ax: Optional = None):


if __name__ == '__main__':
plot_dw4()
# Should be run from repo base directory to work.
plot_dw4()
144 changes: 144 additions & 0 deletions examples/analyse_results/evaluate_checkpoints_fab.py
@@ -0,0 +1,144 @@
import jax.random
from hydra import compose, initialize
import hydra
import pathlib

from examples.create_train_config import setup_logger
from examples.load_flow_and_checkpoint import load_flow
from examples.analyse_results.get_wandb_runs import download_checkpoint
from examples.dw4_fab import load_dataset_original as load_ds_dw4
from examples.dw4_fab import log_prob_fn as log_prob_fn_dw4
from examples.lj13_fab import load_dataset as load_ds_lj13
from examples.lj13_fab import log_prob_fn as log_prob_fn_lj13


from typing import Union

import chex
import jax
from functools import partial

from molboil.train.base import eval_fn
from train.max_lik_train_and_eval import get_eval_on_test_batch

from fabjax.sampling import build_smc, build_blackjax_hmc, build_metropolis
from train.fab_train_no_buffer import TrainStateNoBuffer
from train.fab_train_with_buffer import TrainStateWithBuffer
from train.fab_eval import fab_eval_function

# Relative prefix used to build the hydra `config_path` in `download_checkpoint_and_eval`.
# NOTE(review): presumably points from this file's directory to the repo root — confirm.
_BASE_DIR = '../..'

# Names of the FAB problems handled by this script and their hydra config files.
# NOTE(review): neither list is referenced in the visible part of this file.
problems = ["dw4_fab", "lj13_fab"]
hydra_configs = ["dw4_fab.yaml", "lj13_fab.yaml"]


def get_setup_info(problem: str):
    """Return the wandb tags, hydra config name and checkpoint iteration for a problem.

    Args:
        problem: Either ``"dw4_fab"`` or ``"lj13_fab"``.

    Returns:
        A tuple ``(tags, hydra_config, max_iter)``:
        ``tags`` is a freshly created list of wandb tags (the three shared tags
        followed by the problem-specific one), ``hydra_config`` is
        ``problem + ".yaml"`` and ``max_iter`` is the iteration number later
        passed to ``download_checkpoint``.

    Raises:
        ValueError: If ``problem`` is not one of the supported names. An explicit
            exception is used instead of ``assert`` because assertions are removed
            under ``python -O``, which would silently fall through to the lj13
            settings for any unknown problem.
    """
    # problem name -> (problem-specific wandb tag, max_iter of the checkpoint)
    per_problem = {
        "dw4_fab": ("dw4", 20000),
        "lj13_fab": ("lj13", 14000),
    }
    if problem not in per_problem:
        raise ValueError(
            f"Unknown problem {problem!r}; expected one of {sorted(per_problem)}."
        )
    problem_tag, max_iter = per_problem[problem]

    # Build a new list on every call so callers may mutate it freely.
    tags = ["post_sub", "cblgpu", "fab", problem_tag]
    hydra_config = problem + ".yaml"
    return tags, hydra_config, max_iter


def setup_eval(cfg, flow, target_log_p_x_fn, test_data):
    """Build the evaluation callable for a FAB-trained flow.

    Args:
        cfg: Config object; the fields read here are ``cfg.fab.*`` and
            ``cfg.training.{K_marginal_log_lik, eval_batch_size}``.
        flow: Flow object; ``flow.dim_x`` and ``flow.n_augmented`` are read here and
            the flow itself is forwarded to the evaluation helpers.
        target_log_p_x_fn: Target log-density function, passed to
            ``fab_eval_function`` as ``log_p_x``.
        test_data: Test dataset; ``test_data.features`` is read here and the dataset
            is forwarded to ``eval_fn``.

    Returns:
        A function ``evaluation_fn(state, key) -> dict`` that merges the output of
        ``eval_fn`` on the test data with the output of ``fab_eval_function``.
    """
    fab_cfg = cfg.fab
    hmc_enabled = fab_cfg.use_hmc
    n_dists = fab_cfg.n_intermediate_distributions
    spacing = fab_cfg.spacing_type

    # Total dimensionality handed to the transition operator.
    # NOTE(review): features.shape[-2] is presumably the number of nodes — confirm.
    per_row_dim = flow.dim_x * (flow.n_augmented + 1)
    n_rows = test_data.features.shape[-2]
    dim_total = int(per_row_dim * n_rows)

    if hmc_enabled:
        transition_operator = build_blackjax_hmc(
            dim=dim_total, **fab_cfg.transition_operator.hmc)
    else:
        transition_operator = build_metropolis(
            dim_total, **fab_cfg.transition_operator.metropolis)

    eval_on_test_batch_fn = partial(
        get_eval_on_test_batch,
        flow=flow,
        K=cfg.training.K_marginal_log_lik,
        test_invariances=True,
    )

    # AIS with p as the target. Note that step size params will have been tuned for alpha=2.
    smc_eval = build_smc(
        transition_operator=transition_operator,
        n_intermediate_distributions=n_dists,
        spacing_type=spacing,
        alpha=1.,
        use_resampling=False,
    )

    def evaluation_fn(state: Union[TrainStateNoBuffer, TrainStateWithBuffer], key: chex.PRNGKey) -> dict:
        # NOTE(review): the same `key` is passed to both evaluation helpers below.
        metrics = eval_fn(
            test_data, key, state.params,
            eval_on_test_batch_fn=eval_on_test_batch_fn,
            eval_batch_free_fn=None,
            batch_size=cfg.training.eval_batch_size,
        )
        fab_metrics = fab_eval_function(
            state=state,
            key=key,
            flow=flow,
            smc=smc_eval,
            log_p_x=target_log_p_x_fn,
            features=test_data.features[0],
            batch_size=cfg.fab.eval_total_batch_size,
            inner_batch_size=cfg.fab.eval_inner_batch_size,
        )
        metrics.update(fab_metrics)
        return metrics

    return evaluation_fn


def download_checkpoint_and_eval(problem, seed, flow_type):
    """Download a checkpoint from wandb, evaluate it and log the resulting metrics.

    All file paths used here are relative, so the result depends on the current
    working directory.

    Args:
        problem: Problem name understood by ``get_setup_info``
            (``"dw4_fab"`` or ``"lj13_fab"``).
        seed: Seed of the training run; written to ``cfg.training.seed`` and used to
            select the checkpoint.
        flow_type: Flow type; written to ``cfg.flow.type`` and used to select the
            checkpoint.
    """
    print(f"evaluating checkpoint for {problem} {seed} {flow_type}")

    # Setup
    tags, hydra_config, max_iter = get_setup_info(problem)

    # Reset the global hydra state before initialising it again; this function is
    # called several times per process from the `__main__` loop.
    hydra.core.global_hydra.GlobalHydra.instance().clear()
    initialize(config_path=f"{_BASE_DIR}/examples/config/")
    cfg = compose(config_name=hydra_config)
    cfg.flow.type = flow_type
    cfg.training.seed = seed

    load_dataset = load_ds_dw4 if problem == "dw4_fab" else load_ds_lj13
    target_log_p_x_fn = log_prob_fn_dw4 if problem == "dw4_fab" else log_prob_fn_lj13

    # NOTE(review): `[:-4]` strips only "yaml", so the directory name keeps a trailing
    # dot (e.g. "dw4_fab."). Left unchanged so existing checkpoint locations stay valid.
    base_dir = f'./examples/analyse_results/{hydra_config[:-4]}/models'
    pathlib.Path(base_dir).mkdir(parents=True, exist_ok=True)

    # Download checkpoint from WANDB.
    download_checkpoint(flow_type=flow_type, tags=tags, seed=seed, max_iter=max_iter,
                        base_path=base_dir)
    print("checkpoint downloaded")

    checkpoint_path = f"examples/analyse_results/{hydra_config[:-4]}/models/{flow_type}_seed{seed}.pkl"

    flow, state = load_flow(cfg, checkpoint_path)
    if problem == "lj13_fab":
        # for lj13 we used multiple devices, so keep only the first entry of every leaf.
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        state = jax.tree_util.tree_map(lambda x: x[0], state)
        print("loaded checkpoint from first device")
    else:
        print("loaded checkpoint (single device)")

    # Flip to True for a quick smoke test with tiny evaluation sizes.
    debug = False
    if debug:
        cfg.training.test_set_size = 10
        cfg.training.eval_model_samples = 100
        cfg.training.eval_batch_size = 10
        cfg.training.K_marginal_log_lik = 2

    train_data, test_data = load_dataset(cfg.training.train_set_size, cfg.training.test_set_size)

    # Named `evaluate` (not `eval_fn`) to avoid shadowing the module-level `eval_fn` import.
    evaluate = setup_eval(cfg=cfg, flow=flow, target_log_p_x_fn=target_log_p_x_fn, test_data=test_data)

    key = jax.random.PRNGKey(0)

    cfg.logger.wandb.tags = [problem, "evaluation", "eval_2"]
    cfg.logger.wandb.name = problem + "_evaluation"
    logger = setup_logger(cfg)
    try:
        info = evaluate(state, key)
        logger.write(info)
        print(info)
    finally:
        # Always close the logger, even if evaluation or writing fails.
        logger.close()
    print("evaluation complete")


if __name__ == '__main__':
    # Evaluate every (flow type, seed) combination for the lj13 FAB runs.
    # The paths used by `download_checkpoint_and_eval` are relative, so the
    # working directory matters when launching this script.
    lj13_flow_types = ["along_vector", "spherical", "proj", "non_equivariant"]
    lj13_seeds = [0, 1, 2]
    for flow_type in lj13_flow_types:
        for seed in lj13_seeds:
            download_checkpoint_and_eval(
                problem="lj13_fab",
                seed=seed,
                flow_type=flow_type,
            )
72 changes: 46 additions & 26 deletions examples/analyse_results/fab_pull_results_and_populate_table.py
Expand Up @@ -3,15 +3,21 @@
from examples.analyse_results.get_wandb_runs import get_run_history


_TAGS = ['final_run', 'fab']
_TAGS_DW4 = ['post_sub1','cblgpu', 'fab', "dw4"]
_TAGS_Lj13_eval = ["lj13_fab", "evaluation", "eval_2"]
_TAGS_LJ13_run = ['post_sub','cblgpu', 'fab', "lj13"]

def download_eval_metrics(problem="dw4",
def download_eval_metrics(tags,
n_runs=3,
flow_types=('spherical', 'along_vector', 'proj', 'non_equivariant'),
step_number=-1):
seeds = [0, 1, 2, 3, 4]
tags = _TAGS.copy()
tags.append(problem)
step_number=-1,
allow_single_step: bool = True,
fields=('marginal_log_lik', 'lower_bound_marginal_gap',
'eval_ess_flow', 'eval_ess_ais',
'_runtime',
"_step")):
fields = list(fields)
seeds = [0, 1, 2]
data = pd.DataFrame()

i = 0
Expand All @@ -23,11 +29,12 @@ def download_eval_metrics(problem="dw4",
iter_n = step_number
for seed in seeds:
try:
hist = get_run_history(flow_type, tags, seed, fields=['marginal_log_lik', 'lower_bound_marginal_gap',
'eval_ess_flow', 'eval_ess_ais', '_runtime',
"_step"])
info = dict(hist.iloc[iter_n])
if info["_step"] == 0:
hist = get_run_history(flow_type, tags, seed, fields=fields)
if iter_n == -1:
info = dict(hist.iloc[iter_n])
else:
info = dict(hist[hist["_step"] == iter_n].iloc[0])
if info["_step"] == 0 and not allow_single_step:
print(f"skipping {flow_type} seed={seed} as it only has 1 step")
continue
info.update(flow_type=flow_type, seed=seed)
Expand All @@ -37,24 +44,30 @@ def download_eval_metrics(problem="dw4",
if n_runs_found == n_runs:
break
except:
pass
# print(f"No runs for for flow_type {flow_type}, tags {tags} seed {seed} found!")
print(f"No runs for for flow_type {flow_type}, tags {tags} seed {seed} found!")

if n_runs_found != n_runs:
print(f"Less than {n_runs} runs found for flow {flow_type}")
print(f"Less than {n_runs} runs found for flow {flow_type} tags {tags}")
return data.T


def create_latex_table():
step_numbers_dw4 = [2,3,3,3]
step_numbers_lj13 = [10,6,6,6]
flow_types = ['non_equivariant', 'along_vector', 'proj', 'spherical'] #
row_names = ['\\' + "noneanf", "\\vecproj \ \eanf", "\\cartproj \ \eanf",
"\\sphproj \ \eanf"]
keys = ['eval_ess_flow', 'eval_ess_ais', 'marginal_log_lik', 'lower_bound_marginal_gap', 'runtime']
keys = ['eval_ess_flow', 'eval_ess_ais', 'marginal_log_lik', 'lower_bound_marginal_gap', '_runtime']
eval_step_number_dw4 = [64008, -1, -1, -1]

data_dw4 = download_eval_metrics("dw4", flow_types=flow_types, step_number=step_numbers_dw4)
data_lj13 = download_eval_metrics("lj13", flow_types=flow_types, step_number=step_numbers_lj13)

data_dw4 = download_eval_metrics(_TAGS_DW4, flow_types=flow_types, step_number=eval_step_number_dw4)

data_lj13 = download_eval_metrics(_TAGS_Lj13_eval, flow_types=flow_types, step_number=-1,
fields = ('marginal_log_lik', 'lower_bound_marginal_gap',
'eval_ess_flow', 'eval_ess_ais',
"_step"))
lj13_runtime = download_eval_metrics(_TAGS_LJ13_run, flow_types=flow_types, step_number=-1,
fields=('_runtime',"_step"))
data_lj13["_runtime"] = lj13_runtime["_runtime"]



Expand Down Expand Up @@ -96,23 +109,30 @@ def create_latex_table():
f"{-means_lj13.loc[flow_type]['marginal_log_lik']:.2f},{sem_lj13.loc[flow_type]['marginal_log_lik']:.2f} \\\ \n"
# f"0,0 & 0,0 & 0,0 \\\ \n"

ess_table_string += f"{row_names[i]} & " \
f"{means_dw4.loc[flow_type]['eval_ess_ais'] * 100:.2f},{sem_dw4.loc[flow_type]['eval_ess_ais'] * 100:.2f} \\\ \n"


runtime_table_string += \
f"{row_names[i]} & " \
f"{means_dw4.loc[flow_type]['_runtime']/3600:.1f},{sem_dw4.loc[flow_type]['_runtime']/3600:.1f} & " \
f"{means_lj13.loc[flow_type]['_runtime']/3600:.1f},{sem_lj13.loc[flow_type]['_runtime']/3600:.1f} \\\ \n "


table_lower_bound_gap += \
f"{row_names[i]} & " \
f"{means_dw4.loc[flow_type]['lower_bound_marginal_gap']:.2f},{sem_dw4.loc[flow_type]['lower_bound_marginal_gap']:.2f} & " \
f"{-means_lj13.loc[flow_type]['lower_bound_marginal_gap']:.2f},{sem_lj13.loc[flow_type]['lower_bound_marginal_gap']:.2f} \\\ \n "
# table_lower_bound_gap += \
# f"{row_names[i]} & " \
# f"{means_dw4.loc[flow_type]['lower_bound_marginal_gap']:.2f},{sem_dw4.loc[flow_type]['lower_bound_marginal_gap']:.2f} & " \
# f"{-means_lj13.loc[flow_type]['lower_bound_marginal_gap']:.2f},{sem_lj13.loc[flow_type]['lower_bound_marginal_gap']:.2f} \\\ \n "
# f"0,0 \\\ \n"

# print(table_values_string)
# print("\n\n")
print("main table results")
print(table_v2_string)
print("\n\n")
print("appendix ess table results")
print(ess_table_string)
# print(table_lower_bound_gap)
# print("\n\n")
print("\n\n")
print("runtime table")
print(runtime_table_string)


Expand Down

0 comments on commit 6e324cd

Please sign in to comment.