In [61]:
import jax
import jax.numpy as jnp
import chex
import optax
import matplotlib.pyplot as plt

from distributions import AnnealedDistribution, Target
from distributions.multivariate_gaussian import MultivariateGaussian
from utils.distributions import compute_distances

In [69]:
class TimeDependentLennardJonesEnergyButler(Target):
    TIME_DEPENDENT = True

    def __init__(
        self,
        dim: int,
        n_particles: int,
        alpha: float = 0.5,
        sigma: float = 1.0,
        sigma_cutoff: float = 2.5,
        epsilon_val: float = 1.0,
        min_dr: float = 1e-4,
        n: float = 1,
        m: float = 1,
        c: float = 0.5,
        log_prob_clip: float = None,
        score_norm: float = None,
    ):
        super().__init__(
            dim=dim,
            log_Z=None,
            can_sample=False,
            n_plots=10,
            n_model_samples_eval=1000,
            n_target_samples_eval=None,
        )
        self.n_particles = n_particles
        self.n_spatial_dim = dim // n_particles

        self.alpha = alpha
        self.sigma = sigma
        self.cutoff = sigma_cutoff * sigma

        self.epsilon_val = epsilon_val
        self.min_dr = min_dr
        self.n = n
        self.m = m
        self.c = c

        self.log_prob_clip = log_prob_clip
        self.score_norm = score_norm

    def soft_core_lennard_jones_potential(
        self,
        pairwise_dr: jnp.ndarray,
        t: float,
    ) -> jnp.ndarray:
        """
        Compute the time-dependent soft-core Lennard-Jones potential.

        Args:
            pairwise_dr (jnp.ndarray): Pairwise distances of shape [n_pairs].
            t (float): Time parameter, influencing the strength of the potential (lambda).

        Returns:
            jnp.ndarray: Time-dependent soft-core Lennard-Jones potential energy of shape [].
        """

        inv_r6 = (pairwise_dr / self.sigma) ** 6
        soft_core_term = self.alpha * (1 - t) ** self.m + inv_r6
        lj_energy = (
            self.epsilon_val * t**self.n * (soft_core_term**-2 - 2 * soft_core_term**-1)
        )

        print(lj_energy)
        # Apply cutoff: set energy to zero for distances > cutoff
        lj_energy = jnp.where(pairwise_dr <= self.cutoff, lj_energy, 0.0)

        print(lj_energy)
        # Sum over all pairs to get total energy per sample
        total_lj_energy = jnp.sum(lj_energy, axis=-1)

        return total_lj_energy

    def harmonic_potential(self, x):
        """
        Compute the harmonic potential energy.

        E^osc(x) = 1/2 * Σ ||xi - x_COM||^2
        """
        x = x.reshape(self.n_particles, self.n_spatial_dim)
        x_com = jnp.mean(x, axis=0)
        distances_to_com = optax.safe_norm(
            x - x_com,
            axis=-1,
            min_norm=0.0,
        )

        return 0.5 * jnp.sum(distances_to_com**2)

    def compute_time_dependent_lj_energy(
        self,
        x: jnp.ndarray,
        t: float,
    ) -> jnp.ndarray:
        """
        Compute the total time-dependent soft-core Lennard-Jones energy for a batch of samples.

        Args:
            x (jnp.ndarray): Input array of shape [n_particles * n_spatial_dim].
            t (float): Time parameter.

        Returns:
            jnp.ndarray: Total time-dependent Lennard-Jones energy.
        """
        pairwise_dr = compute_distances(
            x,
            n_particles=self.n_particles,
            n_dimensions=self.n_spatial_dim,
            min_dr=self.min_dr,
        )
        lj_energy = self.soft_core_lennard_jones_potential(pairwise_dr, t)

        if self.log_prob_clip is not None:
            lj_energy = jnp.clip(lj_energy, -self.log_prob_clip, self.log_prob_clip)

        harmonic_energy = self.harmonic_potential(x)

        return lj_energy + self.c * harmonic_energy

    def log_prob(self, x: chex.Array) -> chex.Array:
        return -self.compute_time_dependent_lj_energy(x, 1.0)

    def time_dependent_log_prob(self, x: chex.Array, t: float) -> chex.Array:
        p_t = -self.compute_time_dependent_lj_energy(x, t)
        return p_t

    def score(self, x: chex.Array, t: float) -> chex.Array:
        sc = jax.grad(self.time_dependent_log_prob, argnums=0)(x, t)

        if self.score_norm is not None:
            norm = optax.safe_norm(sc, axis=-1, min_norm=1e-6)
            scale = jnp.clip(self.score_norm / (norm + 1e-6), a_min=0.0, a_max=1.0)

            return sc * scale
        else:
            return sc

    def sample(
        self, key: jax.random.PRNGKey, sample_shape: chex.Shape = ()
    ) -> chex.Array:
        raise NotImplementedError(
            "Sampling is not implemented for TimeDependentLennardJonesEnergy"
        )

    def interatomic_dist(self, x):
        x = x.reshape(-1, self.n_particles, self.n_spatial_dim)
        distances = jax.vmap(lambda x: compute_distances(x))(x)

        return distances

    def batched_log_prob(self, xs, t):
        return jax.vmap(self.time_dependent_log_prob, in_axes=(0, None))(xs, t)

    def visualise(self, samples: chex.Array) -> plt.Figure:
        # Fill samples nan values with zeros
        samples = jnp.nan_to_num(samples, nan=0.0, posinf=1.0, neginf=-1.0)

        # Since we don't have a test set, we will just visualize the samples
        fig, axs = plt.subplots(1, 2, figsize=(12, 4))

        dist_samples = self.interatomic_dist(samples)

        axs[0].hist(
            dist_samples.flatten(),
            bins=100,
            alpha=0.5,
            density=True,
            histtype="step",
            linewidth=4,
        )
        axs[0].set_xlabel("Interatomic distance")
        axs[0].legend(["Generated data"])

        energy_samples = -self.batched_log_prob(samples, 1.0)
        # Clip energy values for visualization
        energy_samples = jnp.nan_to_num(
            energy_samples, nan=0.0, posinf=100.0, neginf=-100.0
        )

        # Determine histogram range from cleaned data
        min_energy = jnp.min(energy_samples)
        max_energy = jnp.max(energy_samples)

        # Add padding to range
        energy_range = (
            min_energy - 0.1 * abs(min_energy),
            max_energy + 0.1 * abs(max_energy),
        )

        axs[1].hist(
            energy_samples,
            bins=100,
            density=True,
            alpha=0.4,
            range=energy_range,
            color="r",
            histtype="step",
            linewidth=4,
            label="Generated data",
        )
        axs[1].set_xlabel("Energy")
        axs[1].legend()

        fig.canvas.draw()
        return fig


In [70]:
jax.config.update("jax_platform_name", "cpu")

key = jax.random.PRNGKey(12391)

initial_density = MultivariateGaussian(dim=2, mean=0.0, sigma=2.0)
target_density = TimeDependentLennardJonesEnergyButler(
    dim=2,
    n_particles=2,
    alpha=0.2,
    sigma=1.0,
    epsilon_val=1.0,
    min_dr=1e-6,
    n=1,
    m=1,
    c=0.5,
)

path_density = AnnealedDistribution(
    initial_density=initial_density, target_density=target_density
)


In [71]:
target_density.time_dependent_log_prob(jnp.array([[0.0, 2.4]]), 1.)

[-0.01043818]
[-0.01043818]


Array(-0.7095618, dtype=float32)

In [72]:
from utils.distributions import compute_distances

In [75]:
dist = compute_distances(jnp.array([[0.0, 2.6]]), n_particles=2, n_dimensions=1)

In [76]:
target_density.soft_core_lennard_jones_potential(dist, 1.0)

[-0.00646378]
[0.]


Array(0., dtype=float32)