In [None]:
import os
# Expose only physical GPU 3 to this process; set before jax is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"  
import jax
from jax import numpy as np
from jax import vmap
import optax
import numpy as onp
from typing import Callable, Tuple
import sys
import argparse
# Fraction of GPU memory the XLA client is allowed to preallocate.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85'


# Network

In [None]:
import haiku as hk
import jax.numpy as np
import numpy as onp
from typing import Optional, Tuple, Callable, Union
import jax
from jax import vmap
from dataclasses import dataclass

###################################################################################################################
def construct_mlp_layers(
    n_hidden: int,
    n_neurons: int, 
    act: Callable[[np.ndarray], np.ndarray], 
    output_dim: int,
    residual_blocks: bool = True 
) -> list:
    """Assemble the ordered list of callables that make up an MLP.

    Each hidden layer contributes an ``hk.Linear(n_neurons)`` followed by a
    nonlinearity. With ``residual_blocks`` enabled, every hidden layer except
    the first uses ``x + act(x)`` as its nonlinearity instead of ``act(x)``.
    A final ``hk.Linear(output_dim)`` with no activation closes the list.

    Args:
        n_hidden: Number of hidden layers in the MLP.
        n_neurons: Number of neurons per hidden layer.
        act: Activation function.
        output_dim: Dimension of the output.
        residual_blocks: Whether or not to use the ``x + act(x)`` form.

    Returns:
        A list suitable for ``hk.Sequential``. Because it instantiates Haiku
        modules, it has to be built inside a transformed function.
    """
    def skip_act(x):
        return x + act(x)

    stack = []
    for depth in range(n_hidden):
        # The first hidden layer always uses the plain activation.
        use_skip = residual_blocks and depth > 0
        stack.extend([hk.Linear(n_neurons), skip_act if use_skip else act])

    ## output layer (no activation)
    stack.append(hk.Linear(output_dim))
    return stack




###################################################################################################################
def construct_score_network(
    d: int,
    n_hidden: int,
    n_neurons: int,
    act: Callable[[np.ndarray], np.ndarray],
    residual_blocks: bool = True,
    is_gradient: bool = False
) -> Tuple[Callable, Callable]:
    """Build a Haiku-transformed score network for a non-interacting system.

    Args:
        d: System dimension.
        n_hidden: Number of hidden layers in the network.
        n_neurons: Number of neurons per layer.
        act: Activation function.
        residual_blocks: Whether or not to use residual blocks.
        is_gradient: Whether or not to compute the score as the gradient of a
            scalar potential.

    Returns:
        A pair ``(score, potential)`` of transformed functions without RNG
        arguments. When ``is_gradient`` is False the MLP outputs the score
        directly and ``potential`` is ``None``.
    """
    # A potential is scalar-valued; a direct score has one output per dimension.
    out_features = d if not is_gradient else 1

    def forward(x):
        mlp = hk.Sequential(
            construct_mlp_layers(n_hidden, n_neurons, act, out_features, residual_blocks)
        )
        return mlp(x)

    if not is_gradient:
        return hk.without_apply_rng(hk.transform(forward)), None

    def potential(x):
        return np.squeeze(forward(x))

    score = jax.grad(potential)
    return hk.without_apply_rng(hk.transform(score)), \
            hk.without_apply_rng(hk.transform(potential))


# Losses

In [None]:
import jax
from jax.flatten_util import ravel_pytree
import jax.numpy as np
from typing import Tuple, Callable, Union, Optional
from jax import jit, vmap, value_and_grad
from jaxopt.linear_solve import solve_cg
from functools import partial
import haiku as hk
import optax


#######################################################################################################################
def grad_log_rho0(
    sample: np.ndarray,
    sig0: float,
    mu0: np.ndarray
) -> np.ndarray:
    """Gradient of the log-density of an isotropic Gaussian at ``sample``.

    Args:
        sample: Point at which to evaluate.
        sig0: Standard deviation of the Gaussian.
        mu0: Mean of the Gaussian.

    Returns:
        ``-(sample - mu0) / sig0**2``, broadcast elementwise.
    """
    variance = sig0**2
    displacement = sample - mu0
    return -displacement / variance

def rho0(
    sample: np.ndarray,
    sig0: float = 1,
    mu0: np.ndarray = 0
) -> np.ndarray:
    """Evaluate the unnormalized Gaussian initial density, elementwise.

    Computes ``exp(-(sample - mu0)**2 / (2 * sig0**2))`` with no normalizing
    constant and no reduction over components.

    Args:
        sample: Point(s) at which to evaluate.
        sig0: Standard deviation of the Gaussian. Defaults to 1.
        mu0: Mean of the Gaussian. Defaults to 0.

    Returns:
        Array with the broadcast shape of ``sample`` and ``mu0``.
    """
    # The arguments are honoured as passed; the defaults reproduce the
    # standard-normal behaviour that calls without sig0/mu0 rely on.
    return np.exp(-(sample - mu0)**2 / (2 * sig0**2))



#######################################################################################################################
@partial(jit, static_argnums=(4, 5, 7))
def B_init_loss(
    B_params: np.ndarray,
    samples: np.ndarray,
    sig0: float,
    mu0: np.ndarray,
    B_apply_score: Callable[[hk.Params, np.ndarray, Optional[float]], np.ndarray],
    time_dependent: bool = False, 
    frame_end: float = 0,
    nt: int = 0
) -> np.ndarray:
    r"""Relative squared error between the B network and ``D * \nabla \log \rho_0``.

    The target is built from the analytic Gaussian score ``grad_log_rho0``
    scaled by ``D``, which is read from the module-level global of that name.

    Args:
        B_params: Parameters of the B network.
        samples: Batch of samples; mapped over the leading axis.
        sig0: Standard deviation of the initial Gaussian.
        mu0: Mean of the initial Gaussian.
        B_apply_score: Apply function ``(params, sample) -> output`` (static).
        time_dependent: Unused in the body (static).
        frame_end: Unused in the body.
        nt: Unused in the body (static).

    Returns:
        ``sum((prediction - target)**2) / sum(target**2)`` over the batch.
    """
    def target_at(sample):
        return D * grad_log_rho0(sample, sig0, mu0)

    def prediction_at(sample):
        return B_apply_score(B_params, sample)

    targets = vmap(target_at)(samples)
    predictions = vmap(prediction_at)(samples)

    squared_error = np.sum((predictions - targets)**2)
    target_energy = np.sum(targets**2)
    return squared_error / target_energy

#######################################################################################################################
@partial(jit, static_argnums=(4, 5, 7))
def L_init_loss(
    L_params: np.ndarray,
    samples: np.ndarray,
    sig0: float,
    mu0: np.ndarray,
    L_apply_score: Callable[[hk.Params, np.ndarray, Optional[float]], np.ndarray],
    time_dependent: bool = False, 
    frame_end: float = 0,
    nt: int = 0
) -> np.ndarray:
    r"""Relative squared error between the L network and the jump-term target.

    The target is built from ``rho0`` evaluated on a grid of shifted samples.
    The grid and weights (``lambdaj``, ``ri``, ``N_ri``, ``sig_jump``) are read
    from module-level globals. ``sig0``, ``mu0``, ``time_dependent``,
    ``frame_end`` and ``nt`` are accepted but not used in the body.

    Args:
        L_params: Parameters of the L network.
        samples: Batch of samples; mapped over the leading axis.
        L_apply_score: Apply function ``(params, sample) -> output`` (static).

    Returns:
        ``sum((prediction - target)**2) / sum(target**2)`` over the batch,
        where ``target`` is the negated per-sample jump term.
    """
    @jax.jit
    def levy_term(sample):
        sample = np.array(sample)
        # Grid of displacements lambda_j * r_i with a trailing singleton axis.
        shifts = np.expand_dims(lambdaj[:, None] * ri[None, :], axis=2)
        shifted = sample - shifts
        density_grid = (jax.vmap(jax.vmap(rho0))(shifted)).squeeze()
        # Average over the lambda axis, then weight along the r axis.
        lambda_avg = np.mean(density_grid, axis=0)
        # NOTE(review): 7 * sig_jump matches the width of the ri grid defined
        # in the configuration cell — confirm that is the intent.
        weighted = lambda_avg * ri * N_ri * 7 * sig_jump
        return np.mean(weighted) / rho0(sample)

    predictions = vmap(
        lambda sample: L_apply_score(L_params, sample)
    )(samples)
    targets = -vmap(levy_term, in_axes=0)(samples)

    return np.sum((predictions - targets)**2) / np.sum(targets**2)

  



#######################################################################################################################
@jit
def compute_grad_norm(
    grads: hk.Params
) -> float:
    """Root-mean-square magnitude of a gradient PyTree.

    The tree is flattened into a single vector; the result is its Euclidean
    norm divided by the square root of the number of entries.
    """
    flat_grads, _ = ravel_pytree(grads)
    n_entries = flat_grads.size
    return np.linalg.norm(flat_grads) / np.sqrt(n_entries)


#######################################################################################################################
@partial(jit, static_argnums=(2, 3))
def update(
    params: hk.Params, 
    opt_state: optax.OptState, 
    opt: optax.GradientTransformation,
    loss_func: Callable[[hk.Params], float],
    loss_func_args: Tuple = tuple(), 
) -> Tuple[hk.Params, optax.OptState, float, hk.Params]:
    """Take one optimizer step on ``params``.

    Args:
        params: Parameters to optimize over.
        opt_state: State of the optimizer.
        opt: Optimizer itself (static under jit).
        loss_func: Loss function; called as ``loss_func(params, *loss_func_args)``
            (static under jit).
        loss_func_args: Extra positional arguments forwarded to ``loss_func``.

    Returns:
        ``(new_params, new_opt_state, loss_value, grads)`` where the loss and
        gradients are evaluated at the incoming ``params``.
    """
    loss_and_grad = value_and_grad(loss_func)
    loss_value, grads = loss_and_grad(params, *loss_func_args)

    param_updates, next_opt_state = opt.update(grads, opt_state, params=params)
    return optax.apply_updates(params, param_updates), next_opt_state, loss_value, grads


# Update

In [None]:
from jax import jit, vmap
import jax
from functools import partial
import jax.numpy as np
import numpy as onp
import haiku as hk
from typing import Callable, Tuple, Union

#######################################################################################################################
@partial(jit, static_argnums=(6,7,8)) 
def update_particles(
    particle_pos: np.ndarray,
    t: float,
    B_params: hk.Params,
    L_params: hk.Params,
    D: Union[np.ndarray, float], 
    dt: float,
    forcing: Callable[[np.ndarray, float], np.ndarray], 
    B_apply_score: Callable[[hk.Params, np.ndarray], np.ndarray], 
    L_apply_score: Callable[[hk.Params, np.ndarray], np.ndarray],
    mask: np.ndarray = None
) -> np.ndarray:
    """Advance a particle by one forward-Euler step of the learned dynamics.

    The velocity is ``forcing(x, t) - B(x) - L(x)``; when ``mask`` is given,
    both network outputs are multiplied by it first. ``D`` is accepted but not
    used in the body. ``forcing`` and the two apply functions are static
    arguments under jit.
    """
    B_score = B_apply_score(B_params, particle_pos)
    L_score = L_apply_score(L_params, particle_pos)

    if mask is None:
        score_term = -B_score - L_score
    else:
        score_term = -mask * B_score - mask * L_score

    velocity = forcing(particle_pos, t) + score_term
    return particle_pos + dt*velocity



#######################################################################################################################
@partial(jit, static_argnums=(5, 6))
def update_particles_EM(
    particle_pos: np.ndarray,
    t: float,
    D_sqrt: Union[np.ndarray, float], 
    dt: float,
    key: np.ndarray, 
    forcing: Callable[[np.ndarray, float], np.ndarray],
    noisy: bool = True,
    mask: np.ndarray = None,
) -> np.ndarray:
    """Advance a particle by one Euler-Maruyama step.

    Args:
        particle_pos: Current position.
        t: Current time, forwarded to ``forcing``.
        D_sqrt: Factor multiplying the Gaussian increment.
        dt: Timestep.
        key: jax PRNG key used for the Gaussian increment.
        forcing: Drift ``(x, t) -> velocity`` (static under jit).
        noisy: If False, only the deterministic drift step is taken (static).
        mask: Optional multiplier applied to the noise term.

    Returns:
        The updated position, same shape as ``particle_pos``.
    """
    drift_step = particle_pos + dt*forcing(particle_pos, t)
    if not noisy:
        return drift_step

    # Standard normal draw scaled by sqrt(2*dt).
    noise = np.sqrt(2*dt) * jax.random.normal(key, shape=particle_pos.shape)
    brownian = -D_sqrt * noise if mask is None else -D_sqrt * mask * noise
    return drift_step + brownian



#######################################################################################################################
def rollout_EM_trajs(
    x0s: np.ndarray,
    nsteps: int,
    t0: float,
    dt: float,
    key: np.ndarray,
    forcing: Callable[[np.ndarray, float], np.ndarray],
    D_sqrt: Union[np.ndarray, float],
    noisy: bool = True
) -> Tuple[onp.ndarray, np.ndarray]:
    """Given a set of initial conditions, create a stochastic trajectory 
    via Euler-Maruyama. Useful for constructing a baseline against which to compare
    the moments.

    Args:
    ------
    x0s: Initial condition. Dimension = n x d where n is the number of samples 
         and d is the dimension of the system.
    nsteps: Number of steps of Euler-Maruyama to take.
    t0: initial time.
    dt: Timestep.
    key: jax PRNG key.
    forcing: Forcing to apply to the particles.
    D_sqrt: Square root of the diffusion matrix.
    noisy: Whether to add the stochastic increment at each step.

    Returns:
    ------
    trajs: Host (NumPy) array of shape (nsteps+1, n, d); trajs[0] is x0s.
    key: PRNG key to carry forward; it has not been used to draw any noise.
    """
    n, d = x0s.shape
    trajs = onp.zeros((nsteps+1, n, d)) 
    trajs[0] = x0s
    step_sample = \
            lambda sample, t, key: update_particles_EM(sample, t, D_sqrt, 
                                                       dt, key, forcing, noisy)
    step_samples = vmap(step_sample, in_axes=(0, None, 0), out_axes=0)

    
    for curr_step in tqdm(range(nsteps)):
        t = t0 + curr_step*dt
        # Split off one extra key to carry forward, so no key is both consumed
        # by a particle's noise draw and split again on the next step.
        split_keys = jax.random.split(key, num=n + 1)
        key, particle_keys = split_keys[0], split_keys[1:]
        trajs[curr_step+1] = step_samples(trajs[curr_step], t, particle_keys)

    return trajs, key 


# Rollouts

In [None]:
from jax import jit, vmap
from jax.tree_util import tree_map
import jax
from functools import partial
import jax.numpy as np
import numpy as onp
import haiku as hk
from typing import Callable, Tuple, Union
import optax
import dill as pickle
import time
from tqdm import tqdm
from jaxlib.xla_extension import Device

# Alias used in type hints for simulation time.
Time = float


def fit_initial_condition(
    n_max_opt_steps: int, 
    ltol: float, 
    B_params: hk.Params, 
    L_params: hk.Params,
    sig0: float,
    mu0: np.ndarray,
    B_score_network: Callable[[hk.Params, np.ndarray], np.ndarray],
    L_score_network: Callable[[hk.Params, np.ndarray], np.ndarray],
    B_opt: optax.GradientTransformation,
    B_opt_state: optax.OptState,
    L_opt: optax.GradientTransformation,
    L_opt_state: optax.OptState,
    samples: np.ndarray,
    time_dependent: bool = False,
    frame_end: float = 0,
    nt: int = 0
) -> Tuple[hk.Params, hk.Params]:
    """Fit both score networks to the initial condition.

    Each iteration takes one optimizer step on the B network (against
    ``B_init_loss``) and one on the L network (against ``L_init_loss``).
    Iteration stops once both losses fall below ``ltol``, after
    ``n_max_opt_steps`` steps, or at the first exception.

    Args:
        n_max_opt_steps: Maximum number of optimization steps.
        ltol: Tolerance on the relative error; both losses must be below it.
        B_params: Initial parameters of the B network.
        L_params: Initial parameters of the L network.
        sig0: Standard deviation of the target initial condition.
        mu0: Mean of the target initial condition.
        B_score_network: Transformed B network; its ``.apply`` is used.
        L_score_network: Transformed L network; its ``.apply`` is used.
        B_opt: Optimizer for the B network.
        B_opt_state: State of the B optimizer.
        L_opt: Optimizer for the L network.
        L_opt_state: State of the L optimizer.
        samples: Samples to optimize over.
        time_dependent: Forwarded to the loss functions.
        frame_end: Forwarded to the loss functions.
        nt: Forwarded to the loss functions.

    Returns:
        ``(B_params, L_params)`` as of the last completed update. Optimizer
        states are not returned.
    """
    
    B_apply_score = B_score_network.apply 
    B_loss_func = lambda B_params: \
            B_init_loss(B_params, samples, sig0, mu0, B_apply_score, 
                             time_dependent, frame_end, nt)

    L_apply_score = L_score_network.apply 
    L_loss_func = lambda L_params: \
            L_init_loss(L_params, samples, sig0, mu0, L_apply_score, 
                             time_dependent, frame_end, nt)
            
    B_loss_val = np.inf
    L_loss_val = np.inf
    try:
        with tqdm(range(n_max_opt_steps)) as pbar:
            pbar.set_description("Initial optimization")
            for curr_step in pbar:
                try:
                    B_params, B_opt_state, B_loss_val, B_grads = update(B_params, B_opt_state, B_opt, B_loss_func)
                    L_params, L_opt_state, L_loss_val, L_grads = update(L_params, L_opt_state, L_opt, L_loss_func)
                    pbar.set_postfix(B_loss=B_loss_val, L_loss=L_loss_val)
                    if (B_loss_val < ltol) and (L_loss_val < ltol):
                        break
                except Exception as e:
                    # Best effort: report the failure and keep whatever
                    # parameters the last successful step produced.
                    print(f"Error at step {curr_step}: {e}")
                    break
    except Exception as e:
        print(f"Error initializing the progress bar: {e}")

    return B_params, L_params


# Drifts


In [None]:
from jax import vmap
from jax.lax import stop_gradient
import jax.numpy as np
from typing import Callable, Tuple



############################################################################################################################
def example_1(
    x: np.ndarray,
    t: float,
    V0: float,
    Gamma: float,
    L: float,
) -> np.ndarray:
    """Time-independent periodic drift with two cosine harmonics.

    Returns
    ``-(V0/Gamma) * (2*pi/L * cos(2*pi*x/L) + pi/L * cos(4*pi*x/L))``.
    The time argument is accepted for interface compatibility and ignored.
    """
    del t
    first_harmonic = 2*np.pi/L * np.cos(2*np.pi*x/L )
    second_harmonic = 1*np.pi/L * np.cos(4*np.pi*x/L )
    return -(V0/Gamma) * (first_harmonic + second_harmonic)



# SBTM-SIM


In [None]:
from dataclasses import dataclass
from typing import Callable, Tuple, Union
import haiku as hk
import jax
import numpy as onp
from jaxlib.xla_extension import Device
import optax


# Aliases used in type hints for the system state and the simulation time.
State = onp.ndarray
Time = float


@dataclass
class SBTMSim:
    """
    Base class for all SBTM simulations.
    Contains simulation parameters common to all SBTM approaches.

    Two score networks are maintained side by side, with attributes prefixed
    ``B_`` and ``L_``. This class defines its own ``__init__`` that takes a
    single dictionary.
    """
    # initial condition fitting
    n_max_init_opt_steps: int
    init_learning_rate: float
    init_ltol: float
    sig0: float
    mu0: onp.ndarray

    # system parameters
    drift: Callable[[State, Time], State]
    force_args: Tuple
    amp: Callable[[Time], float]
    freq: float
    dt: float
    D: onp.ndarray
    D_sqrt: onp.ndarray
    n: int
    d: int
    N: int

    # timestepping
    ltol: float
    gtol: float
    n_opt_steps: int
    learning_rate: float

    # network parameters
    n_hidden: int
    n_neurons: int
    act: Callable[[State], State]
    residual_blocks: bool
    interacting_particle_system: bool

    # general simulation parameters
    key: onp.ndarray
    B_params_list: list
    L_params_list: list
    all_samples: dict

    # output information
    output_folder: str
    output_name: str


    def __init__(self, data_dict: dict) -> None:
        """Populate the instance attributes from a shallow copy of ``data_dict``."""
        self.__dict__ = data_dict.copy()

    def initialize_forcing(self) -> None:
        """Bind ``force_args`` to ``drift``, exposing ``self.forcing(x, t)``."""
        self.forcing = lambda x, t: self.drift(x, t, *self.force_args)

    def initialize_network_and_optimizer(self) -> None:
        """Initialize the network parameters and optimizer.

        Builds the B and L score networks with the configured architecture
        (including ``residual_blocks``), initializes each from its own PRNG
        subkey, and sets up a RAdam optimizer and a batched apply function
        for each.
        """

        self.B_score_network, self.potential_network= \
            construct_score_network(
                self.d,
                self.n_hidden,
                self.n_neurons,
                self.act,
                residual_blocks=self.residual_blocks,
                is_gradient=False
            )

        self.L_score_network, self.potential_network= \
            construct_score_network(
                self.d,
                self.n_hidden,
                self.n_neurons,
                self.act,
                residual_blocks=self.residual_blocks,
                is_gradient=False
            )
      
        example_x = onp.zeros(self.d) 
        # One subkey per network: the two networks start from different
        # weights, and the carried key is never used for initialization.
        self.key, B_key, L_key = jax.random.split(self.key, 3) 
        
        B_init_params = self.B_score_network.init(B_key, example_x) 
        self.B_params_list = [B_init_params]   
        network_size = jax.flatten_util.ravel_pytree(B_init_params)[0].size 
        print(f'Number of parameters (B): {network_size}')
        print(f'Number of parameters needed for overparameterization (B): ' \
                + f'{self.n*example_x.size}')
        # set up the optimizer
        self.B_opt = optax.radam(self.learning_rate) 
        self.B_opt_state = self.B_opt.init(B_init_params) 
        # set up batching for the score
        self.B_batch_score = jax.vmap(self.B_score_network.apply, in_axes=(None, 0))
        
        L_init_params = self.L_score_network.init(L_key, example_x) 
        self.L_params_list = [L_init_params]   
        network_size = jax.flatten_util.ravel_pytree(L_init_params)[0].size 
        print(f'Number of parameters (L): {network_size}')
        print(f'Number of parameters needed for overparameterization (L): ' \
                + f'{self.n*example_x.size}')
        # set up the optimizer
        self.L_opt = optax.radam(self.learning_rate) 
        self.L_opt_state = self.L_opt.init(L_init_params) 
        # set up batching for the score
        self.L_batch_score = jax.vmap(self.L_score_network.apply, in_axes=(None, 0))

    def fit_init(self, cpu: Device, gpu: Device) -> None:
        """Fit the initial condition.

        Draws ``n`` Gaussian samples, fits both networks to them on ``gpu``
        with fresh AdaBelief optimizers, then stores the fitted parameters on
        ``cpu`` and seeds ``all_samples`` with the drawn samples.
        """
        # draw samples
        samples_shape = (self.n, self.N*self.d) 
        init_samples = self.sig0*onp.random.randn(*samples_shape) + self.mu0[None, :] 
        
        # set up optimizer
        B_init_params = jax.device_put(self.B_params_list[0], gpu) 
        B_opt = optax.adabelief(self.init_learning_rate) 
        B_opt_state = B_opt.init(B_init_params)

        L_init_params = jax.device_put(self.L_params_list[0], gpu) 
        L_opt = optax.adabelief(self.init_learning_rate) 
        L_opt_state = L_opt.init(L_init_params)
        
        B_init_params, L_init_params = fit_initial_condition(
                            self.n_max_init_opt_steps,
                            self.init_ltol,
                            B_init_params,
                            L_init_params,
                            self.sig0,
                            self.mu0,
                            self.B_score_network,
                            self.L_score_network,
                            B_opt,
                            B_opt_state,
                            L_opt,
                            L_opt_state,
                            init_samples
                        )
        self.B_params_list = [jax.device_put(B_init_params, device=cpu)] 
        self.L_params_list = [jax.device_put(L_init_params, device=cpu)] 
        # Both sample histories start from the same initial draw.
        self.all_samples = {'SDE': [init_samples], 'learned': [init_samples]}


In [None]:
###### Entropy Calculation #######
@partial(jit, static_argnums=(5, 6, 7, 8, 9))
def compute_sample_entropy_rate(
    sample: np.ndarray,
    t: float,
    B_params: hk.Params,
    L_params: hk.Params,
    D: Union[np.ndarray, float],
    forcing: Callable[[State, Time], State],
    B_score_network: Callable[[hk.Params, State], State],
    L_score_network: Callable[[hk.Params, State], State],
    noise_free: bool,
    div: bool,
) -> Tuple[float, float, float, float]:
    """Per-sample contributions to the entropy production rate.

    With ``B_st`` and ``L_st`` the outputs of the B and L networks and
    ``vt = forcing(sample, t) - B_st - L_st``, returns the four sums
    ``sum(vt*vt)``, ``sum(forcing*vt)``, ``sum(-L_st*vt)`` and
    ``sum(-B_st*vt)``, in that order.

    ``D``, ``noise_free`` and ``div`` are accepted but not used in the body.
    """
    drift = forcing(sample, t)
    B_st = B_score_network.apply(B_params, sample)
    # L_params must be evaluated with the L network's own apply function.
    L_st = L_score_network.apply(L_params, sample)
    vt = drift - B_st - L_st
    return np.sum(vt*vt), np.sum(drift * vt), np.sum(- L_st * vt), np.sum(- B_st * vt)
    
@partial(jit, static_argnums=(5, 6, 7, 8, 9))
def compute_entropy_rate(
    samples: np.ndarray, 
    t: float, 
    B_params: hk.Params, 
    L_params: hk.Params, 
    D: np.ndarray,
    forcing: Callable[[State, Time], State],
    B_score_network: Callable,
    L_score_network: Callable,
    noise_free: bool,
    div: bool
) -> float:
    """Batch-average the four per-sample entropy-rate contributions.

    Maps ``compute_sample_entropy_rate`` over the leading axis of ``samples``
    (all other arguments are shared) and returns the mean of each of its four
    outputs as ``(eprtot, eprm, epract, eprsys)``.
    """
    # Only the samples are batched; the remaining nine arguments are broadcast.
    batched_rate = vmap(
        compute_sample_entropy_rate, in_axes=(0,) + (None,)*9
    )
    per_sample = batched_rate(
        samples, t, B_params, L_params, D, forcing,
        B_score_network, L_score_network, noise_free, div
    )

    eprtot, eprm, epract, eprsys = (np.mean(per_sample[i]) for i in range(4))
    return eprtot, eprm, epract, eprsys

# SBTM-SEQUENTIAL

In [None]:
from dataclasses import dataclass
import jax
import jax.numpy as np
from jax import vmap
import numpy as onp
import dill as pickle
import time
from jaxlib.xla_extension import Device
from haiku import Params
from typing import Tuple


@dataclass
class SequentialSBTM(SBTMSim): 
    """Sequential SBTM simulation: step the particles, then refit both networks.

    Subclasses provide the loss via ``setup_loss`` (which must set
    ``self.B_loss_func`` and ``self.L_loss_func``) and ``setup_loss_fn_args``.
    """
    n_time_steps: int
    use_SDE: bool
    use_ODE: bool
    save_fac: int    # save to disk every save_fac steps
    store_fac: int   # store samples / parameters every store_fac steps
    means: dict
    covs: dict 
    epr_tot: list 
    epr_m: list 
    epr_act: list 
    epr_sys: list 
    mask: np.ndarray


    def setup_loss(self):
        """Define the loss function. """
        raise NotImplementedError("Please implement in the inheriting class.") 


    
    def setup_loss_fn_args(self, gpu: Device) -> Tuple:
        """Define the arguments to the loss function other than parameters."""
        raise NotImplementedError("Please implement in the inheriting class.")
        

    def setup_batched_steppers(self):
        """Construct convenience functions to step the particles."""
        # Deterministic step under the learned dynamics, batched over samples.
        self.step_learned = vmap(
                lambda B_params, L_params, t, sample: update_particles(
                    sample, 
                    t, 
                    B_params, 
                    L_params,
                    self.D, 
                    self.dt,  
                    self.forcing, 
                    self.B_score_network.apply,
                    self.L_score_network.apply,
                    self.mask
                ),
                in_axes=(None, None, None, 0),
                out_axes=0
        )

        # Euler-Maruyama step, batched over samples and their PRNG keys.
        self.step_SDE = vmap(
                lambda t, sample, key: update_particles_EM(
                    sample, 
                    t, 
                    self.D_sqrt, 
                    self.dt, 
                    key,  
                    self.forcing,
                    True,
                    self.mask
                ),
                in_axes=(None, 0, 0),
                out_axes=0
        )

    def step_samples(
        self,
        step,
        B_params: Params,
        L_params: Params,
        t: float,
        samples: np.ndarray,
        SDE_samples: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Step and save both the SDE and ODE samples."""
        # step learned and SDE particles
        samples = self.step_learned(B_params, L_params, t, samples)
        # Split off one extra key to carry forward, so no key is both consumed
        # by a particle's noise draw and split again on the next step.
        keys = jax.random.split(self.key, num=self.n + 1)
        self.key = keys[0]
        SDE_samples = self.step_SDE(t, SDE_samples, keys[1:])

        # save new samples
        if (step+1) % self.store_fac == 0:
            self.all_samples['learned'].append(onp.array(samples))
            self.all_samples['SDE'].append(onp.array(SDE_samples))

        return samples, SDE_samples


    def setup_learning_samples(
        self,
        samples: np.ndarray,
        SDE_samples: np.ndarray
    ) -> Tuple[np.ndarray]:
        """Set up samples to optimize over.

        Returns a 1-tuple so it can be concatenated with the other loss
        arguments. Raises ``ValueError`` if neither ``use_ODE`` nor
        ``use_SDE`` is set.
        """
        if self.use_ODE and self.use_SDE:
            opt_samples = (np.vstack((samples, SDE_samples)),)
        elif self.use_ODE:
            opt_samples = (samples,)
        elif self.use_SDE:
            opt_samples = (SDE_samples,)
        else:
            raise ValueError('Need to specify learning from ODE or SDE.')

        return opt_samples
    
    def compute_epr(
        self,
        B_params: Params,
        L_params: Params,
        t: float,
        samples: np.ndarray, 
        SDE_samples: np.ndarray
    ) -> None:
        """Append the current entropy-rate estimates to the ``epr_*`` lists.

        The estimates are computed on ``samples``; ``SDE_samples`` is accepted
        but not used.
        """
        ## entropy
        eprtot, eprm, epract, eprsys = compute_entropy_rate(
            samples, t, B_params, L_params, self.D, self.forcing, 
            self.B_score_network, self.L_score_network, noise_free=False, div=True
        )
        self.epr_tot.append(eprtot)
        self.epr_m.append(eprm)
        self.epr_act.append(epract)
        self.epr_sys.append(eprsys)


    def solve_fpe_sequential(self, cpu: Device, gpu: Device):
        """Run the time-stepping loop.

        At every step the particles are advanced, then both networks are
        trained in rounds of ``n_opt_steps`` updates until both gradient norms
        are at most ``gtol``. Parameters are stored every ``store_fac`` steps
        and the simulation is saved every ``save_fac`` steps and once at the
        end.
        """
        self.setup_loss()
        self.setup_batched_steppers()
        # Number of parameter snapshots already stored; offsets the time t.
        nt = len(self.B_params_list) - 1

        
        B_params = jax.device_put(self.B_params_list[-1], gpu) 
        B_opt_state = jax.device_put(self.B_opt_state, gpu)
        L_params = jax.device_put(self.L_params_list[-1], gpu) 
        L_opt_state = jax.device_put(self.L_opt_state, gpu)
        samples = jax.device_put(self.all_samples['learned'][-1], gpu) 
        SDE_samples = jax.device_put(self.all_samples['SDE'][-1], gpu) 

        self.compute_epr(B_params, L_params, t=0, samples=samples, SDE_samples=SDE_samples)
        
        with tqdm(range(self.n_time_steps)) as pbar: 
            for step in pbar:
                t = (nt*self.store_fac + step)*self.dt
                pbar.set_description(f"Dynamics: t={t:.3f}") 

                samples, SDE_samples = self.step_samples(step, B_params, L_params, t, samples, SDE_samples) 
                opt_samples = self.setup_learning_samples(samples, SDE_samples)

                ## perform the optimization
                B_loss_value, B_grad_norm = np.inf, np.inf
                L_loss_value, L_grad_norm = np.inf, np.inf
                num_steps_taken = 0
                # NOTE(review): there is no cap on the number of rounds; this
                # loop only exits once both gradient norms reach gtol.
                while (B_grad_norm > self.gtol) or (L_grad_norm > self.gtol):
                    for curr_opt_step in range(self.n_opt_steps): 
                        # Fresh loss arguments are drawn for every update.
                        loss_func_args = opt_samples + self.setup_loss_fn_args(gpu)
                        start_time = time.time()
                        B_params, B_opt_state, B_loss_value, B_grads \
                                = update(
                                        B_params, 
                                        B_opt_state, 
                                        self.B_opt, 
                                        self.B_loss_func, 
                                        loss_func_args
                                    )
                        L_params, L_opt_state, L_loss_value, L_grads \
                                = update(
                                        L_params, 
                                        L_opt_state, 
                                        self.L_opt, 
                                        self.L_loss_func, 
                                        loss_func_args
                                    )
                        end_time = time.time()

                    B_grad_norm = compute_grad_norm(B_grads)
                    L_grad_norm = compute_grad_norm(L_grads)
                    pbar.set_postfix(
                        B_loss=B_loss_value, L_loss=L_loss_value, ltol=self.ltol,
                        B_grad_norm=B_grad_norm, L_grad_norm=L_grad_norm, gtol=self.gtol,
                        step_time=end_time-start_time
                    ) 
                    
                self.compute_epr(B_params, L_params, t, samples, SDE_samples)
                
                if (step+1) % self.store_fac == 0:
                    self.B_params_list.append(jax.device_put(B_params, cpu))
                    self.L_params_list.append(jax.device_put(L_params, cpu))
                    

                if (step+1) % self.save_fac == 0:
                    self.save_data()
        self.save_data()


    def save_data(self):
        """Pickle a shallow copy of the instance attributes to the output file."""
        data = vars(self).copy()
        # Context manager so the file is flushed and closed even if dumping fails.
        with open(f'{self.output_folder}/{self.output_name}', 'wb') as out_file:
            pickle.dump(data, out_file)




@dataclass
class DenoisingSequentialSBTM(SequentialSBTM):
    """Sequential SBTM with a denoising loss for B and a jump-term loss for L."""
    noise_fac: float  # noise scale used in both losses

    def setup_loss(self):
        """Define ``B_loss_func`` and ``L_loss_func`` and the training batch size.

        The L loss reads the module-level globals ``lambdaj``, ``ri``,
        ``N_ri`` and ``sig_jump``.
        """
        @jax.jit
        def score_for_sample(params, sample):
            """Per-sample loss for the L network."""
            score = self.mask*self.L_score_network.apply(params, sample)
            
            sample = np.array(sample)   

            # Grid of displacements lambda_j * r_i with a trailing singleton axis.
            lambdaj_ri_matrix = lambdaj[:, None] * ri[None, :]  
            lambdaj_ri_matrix = np.expand_dims(lambdaj_ri_matrix, axis=2)
            expanded_samples = sample + lambdaj_ri_matrix  

            def apply_network(exp_sample):
                return self.L_score_network.apply(params, exp_sample)

            # Network output at every shifted sample on the grid.
            all_scores = (jax.vmap(jax.vmap(apply_network))(expanded_samples)).squeeze()
            
            # Average over the lambda axis.
            scores_sumj = np.mean(all_scores, axis=0)
            
            
            return np.sum(self.noise_fac*score**2) + 2 * self.noise_fac* np.squeeze(np.mean(scores_sumj * ri * N_ri * 7 * sig_jump))
            
        def sample_denoising_loss(
            params: Params,
            sample: np.ndarray,
            noise: np.ndarray,
        ) -> float: 
            """
            Compute the denoising loss on a single sample, using antithetic sampling
            over the noise for variance reduction.
            """
            loss = 0
            for sign in [-1, 1]:
                perturbed_sample = sample + self.noise_fac*sign*noise
                score = self.mask*self.B_score_network.apply(params, perturbed_sample)
                # Use the instance's own diffusion coefficient rather than the
                # notebook-level global D.
                loss += np.sum(self.noise_fac*score**2 + (2*self.D) *sign*score*noise)

            return np.squeeze(loss / 2) 

        self.B_loss_func = lambda B_params, samples, noise: np.mean(vmap(sample_denoising_loss, in_axes=(None, 0, 0))(B_params, samples, noise))  
        # `noise` is accepted only so both losses share one call signature.
        self.L_loss_func = lambda L_params, samples, noise: np.mean(vmap(score_for_sample, in_axes=(None, 0))(L_params, samples))

        
        # Training on both ODE and SDE samples doubles the batch.
        if self.use_SDE and self.use_ODE:
            self.n_train_samples = 2*self.n
        else:
            self.n_train_samples = self.n


    def setup_loss_fn_args(self, gpu: Device) -> Tuple:
        """Set up noise arguments for the loss function. """
        noises = onp.random.randn(self.n_train_samples, self.d*self.N)
        loss_func_args = (jax.device_put(noises, gpu),)
        return loss_func_args



In [None]:
#######################################################################################################################
from scipy.stats import norm

## Jump grids and weights, read as globals by L_init_loss and by
## DenoisingSequentialSBTM.setup_loss.
sig_jump = 2/48
mu_jump = 0.1
# ri covers [mu_jump - 3*sig_jump, mu_jump + 4*sig_jump], an interval of width
# 7*sig_jump (the `7 * sig_jump` factor used in the losses).
ri = np.linspace(mu_jump-3*sig_jump, mu_jump+4*sig_jump, 31)
lambdaj = np.linspace(0, 1, 20)
LAMBDA = 30
# Gaussian pdf on the ri grid, scaled by LAMBDA.
N_ri = np.array(norm.pdf(ri, loc=mu_jump, scale=sig_jump)) * LAMBDA


######## Configuration Parameters #########
d      = 1
kappa = 1
eta = 1


kbT = 4.114
V0 = 5*4.114
Gamma = 3.25
L = 40

D      = kbT / Gamma

D_sqrt = onp.sqrt(D)
mask   = onp.array([1])
dt     = 1e-3
tf     = 100
n      = 4000
n_time_steps = int(tf / dt)
store_fac = 5


## configure random seed
repeatable_seed = False
if repeatable_seed:
    key = jax.random.PRNGKey(42)
    onp.random.seed(42)
else:
    key = jax.random.PRNGKey(onp.random.randint(10000))


## set up forcing parameters (extra arguments of example_1: V0, Gamma, L)
drift      = example_1
force_args = (V0,Gamma,L,)


## initial distribution parameters
sig0 = 1.0
mu0  = np.zeros(d)

### setup optimizer
init_learning_rate = 1e-4
init_ltol = 1e-6
ltol = np.inf
gtol = 0.5
n_opt_steps = 25
n_max_init_opt_steps = int(1e4)


### Set up neural network
n_hidden = 3
n_neurons = 32
act = jax.nn.swish
residual_blocks = False


In [None]:
import os
# Results are written under <base_folder>/<system_folder>; the path is relative.
base_folder = 'lcy/epr-levy/experiments/result'
system_folder = 'example1'
output_folder = os.path.join(base_folder, system_folder)

# Create the directory tree up front; an existing folder is not an error.
os.makedirs(output_folder, exist_ok=True)

In [None]:
def construct_simulation(
    learning_rate: float,
    noise_fac: float,
    name_str: str
):
    """Build a ``DenoisingSequentialSBTM`` from the notebook-level configuration.

    Args:
        learning_rate: Learning rate for the time-stepping optimizers.
        noise_fac: Noise scale of the denoising loss.
        name_str: Stem of the output file name; ``.npy`` is appended.

    Returns:
        The configured simulation object. Every other setting is read from
        module-level globals.
    """
    initial_fit = dict(
        n_max_init_opt_steps=n_max_init_opt_steps,
        init_learning_rate=init_learning_rate,
        init_ltol=init_ltol,
        sig0=sig0,
        mu0=mu0,
    )
    system = dict(
        drift=drift,
        force_args=force_args,
        amp=None,
        freq=None,
        dt=dt,
        D=D,
        D_sqrt=D_sqrt,
        n=n,
        N=1,
        d=d,
        mask=mask,
    )
    training = dict(
        ltol=ltol,
        gtol=gtol,
        n_opt_steps=n_opt_steps,
        learning_rate=learning_rate,
        noise_fac=noise_fac,
        use_ODE=True,
        use_SDE=False,
    )
    network = dict(
        n_hidden=n_hidden,
        n_neurons=n_neurons,
        act=act,
        residual_blocks=residual_blocks,
        interacting_particle_system=False,
    )
    bookkeeping = dict(
        key=key,
        B_params_list=[],
        L_params_list=[],
        all_samples=dict(),
        output_folder=output_folder,
        output_name=f'{name_str}.npy',
        n_time_steps=n_time_steps,
        store_fac=store_fac,
        save_fac=250,
        means={'SDE': [], 'learned': []},
        covs={'SDE': [], 'learned': []},
        epr_tot=[],
        epr_m=[],
        epr_act=[],
        epr_sys=[],
    )

    return DenoisingSequentialSBTM(
        **initial_fit, **system, **training, **network, **bookkeeping
    )

In [None]:
# Make modules in the parent directory importable.
sys.path.append('../')
def get_simulation_parameters():
    """Return the manually chosen ``(learning_rate, noise_fac, name_str)``.

    The run name embeds both values and the global final time ``tf``.
    """
    learning_rate, noise_fac = 1e-6, 0.01
    name_str = f'lr={learning_rate}_nf={noise_fac}_tf={tf} copy'
    return learning_rate, noise_fac, name_str

print(jax.devices())
# Handles to the first GPU and the first CPU device reported by JAX.
gpu = jax.devices('gpu')[0]
cpu = jax.devices('cpu')[0]