In [None]:
import os
# Switch to the repository root before the project-local imports below
# (trainer, energy, utility) — presumably they are resolved relative to the
# working directory; confirm against how the project is installed.
# The root can be overridden with the DIFFUSION_SAMPLER_ROOT environment
# variable so the notebook is not tied to one user's home directory; the
# original path is kept as the default.
os.chdir(os.environ.get("DIFFUSION_SAMPLER_ROOT", "/home/guest_dyw/diffusion-sampler/"))

import torch
import matplotlib.pyplot as plt

from omegaconf import DictConfig

from tqdm import tqdm

from trainer.utils.langevin import langevin_dynamics
from energy import (
    AnnealedDensities,
    AnnealedEnergy,
    BaseEnergy,
    get_energy_function,
    GaussianEnergy,
)
from utility import SamplePlotter

In [None]:
# Single configuration object for the whole notebook. It is passed as-is to
# get_energy_function and langevin_dynamics, and cfg.eval.plot is unpacked
# into SamplePlotter.
cfg = DictConfig(
    dict(
        # Batch size drawn from the prior and kept after every annealing step.
        num_samples=1000,
        # Number of points on the linspace(0, 1) annealing grid.
        num_time_steps=3,
        # The five settings below are not read in this notebook directly;
        # presumably they are consumed by langevin_dynamics — verify there.
        max_iter_ls=1000,
        burn_in=100,
        ld_schedule=True,
        ld_step=0.1,
        target_acceptance_rate=0.574,
        device="cuda",
        # Target energy specification handed to get_energy_function via cfg.
        energy=dict(
            _target_="energy.gmm.GMM25",
            dim=2,
        ),
        eval=dict(
            # Keyword arguments forwarded to SamplePlotter.
            plot=dict(
                plotting_bounds=[-15.0, 15.0],
                # Disabled option, kept for reference:
                # projection_dims=[[0, 2], [1, 2], [2, 4], [3, 4], [4, 6], [5, 6]],
                fig_size=[12, 12],
            ),
        ),
    )
)

In [None]:
# Read the run-wide settings from the config once; the cells below use these
# names directly.
device = cfg.device
num_time_steps = cfg.num_time_steps
num_samples = cfg.num_samples

# Target energy, constructed from the config.
energy = get_energy_function(cfg)

# Gaussian prior used as the starting distribution. Its device and
# dimensionality are taken from cfg instead of being repeated as literals, so
# editing cfg.device or cfg.energy.dim cannot leave the prior out of sync with
# the target energy. std=12.0 is a hand-picked width (NOTE(review): presumably
# chosen to cover the plotting bounds of +/-15 — confirm).
prior = GaussianEnergy(device=device, dim=cfg.energy.dim, std=12.0)

# Plot helper; bounds and figure size come from cfg.eval.plot.
plotter = SamplePlotter(energy, **cfg.eval.plot)

# Pairing of target and prior from which AnnealedEnergy objects are built
# for each annealing time in the sampling cell.
annealed_densities = AnnealedDensities(energy, prior)

In [None]:
# Annealed Langevin sampling: start from prior samples and, for each time t on
# a uniform grid over [0, 1], run Langevin dynamics on the annealed energy at
# t, warm-starting from the previous step's samples.
# NOTE(review): presumably t=0 corresponds to the prior and t=1 to the target;
# confirm against AnnealedEnergy.
# NOTE(review): no RNG seed is set in the visible cells, so results will
# differ between runs.
sample = prior.sample(num_samples, device)

for t in torch.linspace(0, 1, num_time_steps):
    print(t)
    annealed_energy = AnnealedEnergy(annealed_densities, t)
    # The second return value of langevin_dynamics is not used here.
    sample, _ = langevin_dynamics(sample, annealed_energy.log_reward, device, cfg)
    # Detach from the autograd graph and keep only the trailing num_samples
    # rows, so the batch fed to the next step has a fixed size.
    sample = sample.detach()[-num_samples:]

    fig, ax = plotter.make_sample_plot(sample)
    plt.show()
    # Release the figure explicitly; otherwise one open figure per annealing
    # step can accumulate in pyplot's figure registry.
    plt.close(fig)