# Continuous GFlowNets on a Simple 1D Line Environment

In [None]:
%reload_ext autoreload
%autoreload 2

from acquisition import *
from acquisition_env import *
from surrogate import *
from utils import *

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import trange

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Apply the seaborn theme FIRST. sns.set() is an alias of sns.set_theme(),
# which rewrites a number of rcParams -- notably it forces "font.family" to
# its own default ("sans-serif"). Calling it after rcParams.update() would
# therefore silently discard the font choice below.
sns.set(style="whitegrid")

# Project-specific overrides, applied last so that they take precedence over
# the seaborn theme.
plt.rcParams.update({
    "text.usetex": True,  # render all text through LaTeX
    "font.family": "Helvetica",
    "figure.dpi": 400,
    "figure.figsize": (8, 4),
})

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Outer loop: how many random problems to solve, and the evaluation budget
# handed to the acquisition environment for each of them.
problems = 20
max_evals = 15

# Estimated lower / upper bound of x, used when drawing the initial point.
min_x, max_x = -15, 15
n_test_samples = 100

# Wider x-range used for the evaluation / plotting grids.
min_graph_x, max_graph_x = -50, 50

# Default plot prefix and model checkpoint locations.
plot_prefix = "gif_gfn/gif"
path_forward = "./models/acq_forward_10p_20240823.torch"
path_backward = "./models/acq_backward_10p_20240823.torch"

# Keyword arguments for the "fast" GP surrogate. The trailing comments keep
# the alternative values noted by the original author.
gp_kwargs = dict(
    n_posterior_samples=200,
    x_bounds=None,  # (min_x, max_x)
    warmup_steps=4,  # 256 or 512
    num_samples=4,  # 256 or 512
    thinning=1,  # 32
    noise_scale=1e-12,  # 1e-8
    jit_compile=True,  # True
    show_plot=False,
    disable_progbar=True,
)

# Keyword arguments for the "slow" GP surrogate: identical to gp_kwargs except
# for the entries listed here (a new dict is built; gp_kwargs is untouched).
gp_slow_kwargs = {
    **gp_kwargs,
    "n_posterior_samples": 2000,
    "warmup_steps": 256,
    "num_samples": 256,
    "thinning": 32,
    "jit_compile": True,
    "show_plot": True,
}

# Keyword arguments forwarded verbatim to AcquisitionTrainer.
# NOTE(review): "init_explortation_noise" looks like a misspelling of
# "exploration", but it has to match the trainer's parameter name, so it is
# kept exactly as-is here.
acq_kwargs = dict(
    training_steps=20,  # 150
    init_state_value=0,
    state_dim=2,
    hid_dim=128,
    lr_model=1e-3,
    lr_logz=1e-2,
    min_policy_std=1e-2,
    max_policy_std=2.0,
    init_explortation_noise=1.0,
    batch_size=512,
    inference_batch_size=10_000,
    trajectory_length=10,
)

In [None]:
# Main experiment loop: for each randomly generated problem, fit a GP
# surrogate to the points evaluated so far, train the acquisition model,
# query the oracle at the point it proposes, and plot the refitted surrogate.

# The data generator and the acquisition trainer are created once and reused
# across all problems.
generator = DataGenerator()
acquisition = AcquisitionTrainer(**acq_kwargs)
# acquisition.load(path_forward, path_backward)

# Evenly spaced column vectors (trailing dim of size 1) over the graph range.
# test_x has n_test_samples points; graph_x has as many points as
# gp_kwargs["n_posterior_samples"] (falling back to 1000 if the key is absent).
# NOTE(review): surrogate_slow is configured with a different
# n_posterior_samples (gp_slow_kwargs) but receives this same graph_x --
# confirm that is intended.
test_x = torch.linspace(min_graph_x, max_graph_x, n_test_samples).unsqueeze(-1)
graph_x = torch.linspace(min_graph_x, max_graph_x, gp_kwargs.get("n_posterior_samples", 1000)).unsqueeze(-1)

for problem_i in range(problems):
    # Draw a fresh target function for this problem.
    oracle = generator.random_gaussian_mixture_pdf()
    # generator.plot(oracle, true_x)

    # Generate "graph" and "test" data
    graph_y = oracle(graph_x)
    test_y = oracle(test_x)

    # Generate initial training data point
    # (presumably a 1x1 tensor drawn from [min_x, max_x] -- TODO confirm
    # against DataGenerator.random_nxm_tensor)
    train_x = generator.random_nxm_tensor(1, 1, min_x, max_x)
    train_y = oracle(train_x)

    # Create acquisition trainer
    # (the trainer itself is reused; only a new environment is attached)
    env = AcqEnvironment(max_evals, test_y)
    acquisition.set_env(env)

    # Create surrogate trainer and fit it on initial data
    # Two surrogates share the same data: `surrogate` uses the cheap settings
    # (gp_kwargs), `surrogate_slow` the expensive ones (gp_slow_kwargs).
    surrogate = BayesianGPTrainer(train_x, train_y, graph_x, **gp_kwargs)
    surrogate_slow = BayesianGPTrainer(train_x, train_y, graph_x, **gp_slow_kwargs)
    surrogate.train()
    posterior = surrogate.get_posterior()
    gp_samples = surrogate.get_samples(posterior)
    acquisition.env.update_prev_posterior(gp_samples)

    # range(1, max_evals) runs max_evals - 1 iterations; eval_i starts at 1.
    for eval_i in range(1, max_evals):
        # Rebinds the module-level plot_prefix set in the configuration cell.
        # NOTE(review): presumably ./plots/gif_gfn_3/ must already exist --
        # confirm whether the trainers create it.
        plot_prefix = f"./plots/gif_gfn_3/gfn_{problem_i:02d}_{eval_i:04d}"
        
        # Train acquisition model
        # NOTE(review): `surrogate` (fast) is only explicitly train()-ed once,
        # before this loop; presumably acquisition.train refits it -- confirm.
        acquisition.train(surrogate, oracle, graph_x, graph_y, plot_prefix)

        # Get new point
        # The proposed points are averaged over dim 0 into a single query
        # location and moved to the CPU; `trajectory` is not used afterwards.
        trajectory, new_x = acquisition.inference()
        new_x = new_x.mean(0).cpu()
        new_y = oracle(new_x)
        surrogate.add_data_point(new_x, new_y)
        surrogate_slow.add_data_point(new_x, new_y)

        # Plot the surrogate
        # Only the slow surrogate is refitted here; its samples become the
        # environment's "previous posterior" for the next iteration.
        # NOTE(review): get_samples is called with mean=True here but without
        # it for the initial fit above -- confirm the difference is deliberate.
        surrogate_slow.train()
        posterior = surrogate_slow.get_posterior()
        gp_samples = surrogate_slow.get_samples(posterior, mean=True)
        acquisition.env.update_prev_posterior(gp_samples)
        surrogate_slow.plot_gp(graph_x, graph_y, posterior, path=f"{plot_prefix}_final.png")


In [None]:
# acquisition.save(path_forward, path_backward)