### Here we will write the particle filtering class

Existing particle filter class, taken from the real-world functions.

In [None]:
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import jax.scipy as jsp

import equinox as eqx

@dataclass(frozen=True)
class ParticleFilter:
    """
    A frozen dataclass implementing a particle filter with support for different sampling
    and weighting strategies based on time periods (minute, overnight, weekend, other).

    At every time step the filter picks a proposal distribution and a weight function
    according to an integer ``tau_id`` (0=minute, 1=overnight, 2=weekend, 3=other).

    Callable signatures, as invoked by this class:
        *_sample_fn(subkey, particles, Y_array, idt) -> sampled_particles
        *_weight_fn(sampled_particles, particles, Y_array, idt) -> log-weight increments

    Attributes:
        min_sample_fn (callable): Sampling function for minute-level data
        min_weight_fn (callable): Weight function for minute-level data
        overnight_sample_fn (callable): Sampling function for overnight periods
        overnight_weight_fn (callable): Weight function for overnight periods
        weekend_sample_fn (callable, optional): Sampling function for weekends.
            Defaults to overnight_sample_fn if None
        weekend_weight_fn (callable, optional): Weight function for weekends.
            Defaults to overnight_weight_fn if None
        other_sample_fn (callable, optional): Sampling function for other periods.
            Defaults to weekend_sample_fn if None
        other_weight_fn (callable, optional): Weight function for other periods.
            Defaults to weekend_weight_fn if None
        final_reweight (callable): Final reweighting function. Defaults to identity
        ESS_COND (float): Relative ESS threshold; resampling happens when
            ESS / N_PARTICLES < ESS_COND. Defaults to 0.5
        N_PARTICLES (int): Number of particles in the filter. Defaults to 2500
        Y_LOOK_FORWARD (int): Number of extra observations Y_array carries beyond
            tau_id_array. Defaults to 0
        switch_resampling_in_step (bool): If True, the incoming particles are resampled
            *before* new particles are proposed. If False (default), the newly sampled
            particles are resampled *after* reweighting.
        needs_final_reweight (bool): Whether final reweighting is needed. Defaults to False
            (True currently raises NotImplementedError in ``simulate``).
    """
    # Constants
    min_sample_fn: callable
    min_weight_fn: callable
    overnight_sample_fn: callable
    overnight_weight_fn: callable

    weekend_sample_fn: callable = None
    weekend_weight_fn: callable = None
    other_sample_fn: callable = None
    other_weight_fn: callable = None

    final_reweight: callable = lambda *args: args

    ESS_COND: float = 0.5
    N_PARTICLES: int = 2500
    Y_LOOK_FORWARD: int = 0
    switch_resampling_in_step: bool = False
    needs_final_reweight: bool = False

    def __post_init__(self):
        """
        Post-initialization setup for the particle filter.

        Sets default functions for weekend and other periods if not provided,
        and initializes diagnostic names for tracking filter performance.
        """
        # Workaround for frozen=True: use object.__setattr__
        if self.weekend_sample_fn is None:
            object.__setattr__(self, 'weekend_sample_fn', self.overnight_sample_fn)
        if self.weekend_weight_fn is None:
            object.__setattr__(self, 'weekend_weight_fn', self.overnight_weight_fn)
        if self.other_sample_fn is None:
            object.__setattr__(self, 'other_sample_fn', self.weekend_sample_fn)
        if self.other_weight_fn is None:
            object.__setattr__(self, 'other_weight_fn', self.weekend_weight_fn)

        # Order must match the tuple emitted per scan step in `simulate`.
        object.__setattr__(self, 'diagnostic_names', ['ess', 'resample_flag', 'normalised_entropy', 'marginal_likelihood'])

    @staticmethod
    def multinomial_resample(subkey, weights, N_particles):
        """
        Perform multinomial resampling of particles based on their weights.

        Args:
            subkey: JAX random key for sampling
            weights: Array of particle weights (must sum to 1)
            N_particles: Number of particles to resample

        Returns:
            tuple: (resample_indices, log_weights, resample_flag)
                - resample_indices: Indices of resampled particles
                - log_weights: Log weights after resampling (uniform)
                - resample_flag: Always 1 (resampling occurred)
        """
        resample_indices = jax.random.choice(
            subkey, N_particles, p=weights, shape=(N_particles,)
        )
        log_weights = jnp.log(1 / N_particles) * jnp.ones_like(weights)
        return resample_indices, log_weights, 1

    @staticmethod
    def get_online_metric_args(log_weights, unormalised_log_weights, n_particles):
        """
        Calculate online diagnostic metrics for the particle filter.

        Args:
            log_weights: Current (normalised) log weights of particles
            unormalised_log_weights: Log weights before normalisation
            n_particles: Number of particles

        Returns:
            tuple: (entropy, marginal_likelihood)
                - entropy: log(n_particles) + sum(w * log w) = log N - H(w). This is 0 for
                  uniform weights and grows as the weights degenerate (i.e. the KL
                  divergence from uniform), despite the "normalised_entropy" name.
                - marginal_likelihood: sum(exp(unormalised_log_weights)), the per-step
                  likelihood increment when the previous weights were normalised.
        """
        entropy = jnp.sum(jnp.exp(log_weights) * log_weights) + jnp.log(n_particles)
        marginal_likelihood = jnp.sum(jnp.exp(unormalised_log_weights))

        return entropy, marginal_likelihood

    @staticmethod
    def calculate_ess(log_weights):
        """
        Calculate the effective sample size (ESS) of the particle weights.

        Args:
            log_weights: Normalised log weights of particles

        Returns:
            float: Effective sample size, 1 / sum(w**2)
        """
        return 1.0 / jnp.exp(jax.scipy.special.logsumexp(2 * log_weights))

    def _maybe_resample(self, subkey, log_weights):
        """Resample when ESS / N_PARTICLES < ESS_COND.

        Returns (ess, particle_indices, log_weights, resample_flag). When resampling is
        not triggered, particle_indices is the identity and log_weights is unchanged.
        """
        ess = self.calculate_ess(log_weights)
        particle_indices, log_weights, resample_flag = jax.lax.cond(
            ess / self.N_PARTICLES < self.ESS_COND,
            lambda k, log_w: self.multinomial_resample(k, jnp.exp(log_w), self.N_PARTICLES),
            lambda _, log_w: (jnp.arange(self.N_PARTICLES), log_w, 0),
            subkey,
            log_weights,
        )
        return ess, particle_indices, log_weights, resample_flag

    def combined_sample_fn(self, tau_id, subkey, particles, Y_array, idt):
        """
        Select and apply the appropriate sampling function based on tau_id.

        Args:
            tau_id: Time period identifier (0=minute, 1=overnight, 2=weekend, 3=other)
            subkey: JAX random key for sampling
            particles: Current particle states
            Y_array: Observation array
            idt: Time index

        Returns:
            array: New sampled particles
        """
        return jax.lax.switch(
            tau_id,
            [
                lambda: self.min_sample_fn(subkey, particles, Y_array, idt),
                lambda: self.overnight_sample_fn(subkey, particles, Y_array, idt),
                lambda: self.weekend_sample_fn(subkey, particles, Y_array, idt),
                lambda: self.other_sample_fn(subkey, particles, Y_array, idt),
            ]
        )

    def combined_weight_fn(self, tau_id, sampled_particles, particles, Y_array, idt):
        """
        Select and apply the appropriate weight function based on tau_id.

        Args:
            tau_id: Time period identifier (0=minute, 1=overnight, 2=weekend, 3=other)
            sampled_particles: Newly sampled particles
            particles: Previous particle states
            Y_array: Observation array
            idt: Time index

        Returns:
            array: Log-weight updates for the particles
        """
        return jax.lax.switch(
            tau_id,
            [
                lambda: self.min_weight_fn(sampled_particles, particles, Y_array, idt),
                lambda: self.overnight_weight_fn(sampled_particles, particles, Y_array, idt),
                lambda: self.weekend_weight_fn(sampled_particles, particles, Y_array, idt),
                lambda: self.other_weight_fn(sampled_particles, particles, Y_array, idt),
            ]
        )

    @eqx.filter_jit
    def simulate(
            self,
            key: jax.random.PRNGKey,
            initial_particles: jnp.ndarray,
            initial_log_weights: jnp.ndarray,
            Y_array: jnp.ndarray,
            tau_id_array: jnp.ndarray,
    ):
        """
        Run the particle filter simulation over the entire time series.

        This method implements a sequential Monte Carlo particle filter that can
        switch between different proposal distributions and weight functions based
        on the tau_id_array. It performs resampling when the effective sample size
        falls below the threshold.

        Args:
            key: JAX random key for the simulation
            initial_particles: Initial particle states
            initial_log_weights: Initial (normalised) log weights of particles
            Y_array: Array of observations, length len(tau_id_array) + Y_LOOK_FORWARD
            tau_id_array: Integer array of time period identifiers for each time step

        Returns:
            tuple: (final_particles, final_log_weights, filter_diagnostics)
                - final_particles: Final particle states
                - final_log_weights: Final log weights
                - filter_diagnostics: Dictionary containing one per-step array for each of
                    ess, resample_flag, normalised_entropy, marginal_likelihood

        Raises:
            AssertionError: If Y_array shape doesn't match tau_id_array shape + Y_LOOK_FORWARD
            NotImplementedError: If final reweighting is requested but not implemented
        """

        assert Y_array.shape[0] == tau_id_array.shape[0] + self.Y_LOOK_FORWARD, "The Y_array is not the correct shape relative to the tau_id_array and the Y_LOOK_FORWARD"

        # Fail before tracing and running the scan rather than after it.
        if self.needs_final_reweight:
            # Implementing this needs the previous-step particles, which the scan does not
            # return. When adding it, terminate the scan one step early.
            raise NotImplementedError("Final reweighting not implemented")

        def particle_filter_step(carry, time_slice):
            tau_id, idt = time_slice
            key, Y_array, particles, log_weights = carry

            if self.switch_resampling_in_step:
                # Resample the *incoming* particles before proposing new ones.
                key, subkey = jax.random.split(key)
                ess, particle_indices, log_weights, resample_flag = self._maybe_resample(subkey, log_weights)
                particles = particles[particle_indices]

            # Sample new particles using the proposal distribution.
            key, subkey = jax.random.split(key)
            sampled_particles = self.combined_sample_fn(tau_id, subkey, particles, Y_array, idt)

            # Update and normalise particle weights.
            weight_update = self.combined_weight_fn(tau_id, sampled_particles, particles, Y_array, idt)
            unormalised_log_weights = log_weights + weight_update
            log_weights = unormalised_log_weights - jsp.special.logsumexp(unormalised_log_weights)

            online_metric_args = self.get_online_metric_args(log_weights, unormalised_log_weights, self.N_PARTICLES)

            if self.switch_resampling_in_step:
                # Resampling already happened; the new particles go into the carry as-is.
                particles = sampled_particles
            else:
                key, subkey = jax.random.split(key)
                ess, particle_indices, log_weights, resample_flag = self._maybe_resample(subkey, log_weights)
                particles = sampled_particles[particle_indices]

            return (key, Y_array, particles, log_weights), (ess, resample_flag, *online_metric_args,)

        # 1. Run scan. Each step receives (tau_id, time index).
        Y_idt_half_array = jnp.arange(tau_id_array.shape[0])
        Y_idt_array = jnp.column_stack((tau_id_array, Y_idt_half_array))

        final_carry, online_metric_list = jax.lax.scan(
            particle_filter_step,
            (key, Y_array, initial_particles, initial_log_weights),
            Y_idt_array
        )

        # 2. Process output of scan.
        _, _, final_particles, final_log_weights = final_carry
        filter_diagnostics = {diagnostic_name: val for diagnostic_name, val in zip(self.diagnostic_names, online_metric_list)}

        return final_particles, final_log_weights, filter_diagnostics


Existing Simulate Forward Class

In [None]:
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import equinox as eqx

import pf_functions.constants as const


def f_from_noise(last_val, noise):
    """
    Advance the hidden (log-volatility) state by one minute-level step.

    Applies the discretised mean-reverting update
        X_t = X_{t-1} + TAU * KAPPA * (X_BAR - X_{t-1}) + sqrt(TAU) * SIGMA_SIGMA * noise

    Args:
        last_val: Previous hidden state value(s)
        noise: Noise draw(s) driving the step (standard normal when called from
            SimulatingForward)

    Returns:
        The next hidden state value(s)
    """
    drift = const.TAU * const.KAPPA * (const.X_BAR - last_val)
    diffusion_scale = jnp.sqrt(const.TAU) * const.SIGMA_SIGMA
    return diffusion_scale * noise + (last_val + drift)

def overnight_f_from_noise(last_val, noise):
    """
    Advance the hidden state across an overnight gap.

    Same mean-reverting update as ``f_from_noise``, but with the step length TAU
    multiplied by a fixed overnight factor (8 * 60 * 0.4 / 0.6). This scales both the
    drift and the diffusion variance.

    Args:
        last_val: Previous hidden state value(s)
        noise: Noise draw(s) driving the step

    Returns:
        The next hidden state value(s)
    """
    # NOTE(review): presumably 8 * 60 minutes of overnight time down-weighted by 0.4/0.6
    # relative to trading time; confirm the intended meaning of the ratio.
    overnight_TAU_adjust = 8 * 60 * 0.4/0.6
    effective_tau = const.TAU * overnight_TAU_adjust

    drift = effective_tau * const.KAPPA * (const.X_BAR - last_val)
    diffusion_scale = jnp.sqrt(effective_tau) * const.SIGMA_SIGMA
    return diffusion_scale * noise + (last_val + drift)

def g_from_total_noise(prev_particle, all_particles, total_noise):
    """
    Generate observations (log returns) from hidden log-volatility states.

    Y = -0.5 * TAU * exp(2 X) + sqrt(TAU) * exp(X) * noise

    Args:
        prev_particle: Previous hidden state. Unused here; accepted so the signature
            matches ``skew_g_from_total_noise``.
        all_particles: Current hidden state particles
        total_noise: Noise draw(s) for the observations

    Returns:
        Generated observations, same shape as the broadcast of the inputs
    """
    drift = -0.5 * const.TAU * jnp.exp(2 * all_particles)
    scale = jnp.sqrt(const.TAU) * jnp.exp(all_particles)
    return drift + scale * total_noise

def skew_g_from_total_noise(prev_particle, particle, noise):
    """
    Generate observations whose noise is correlated (RHO) with the state innovation.

    Given the transition X_{t-1} -> X_t, the observation is Gaussian with
        mean     = -0.5 * TAU * exp(2 X_t) + RHO * exp(X_t) / SIGMA_SIGMA * (X_t - E[X_t | X_{t-1}])
        variance = TAU * exp(2 X_t) * (1 - RHO**2)
    where E[X_t | X_{t-1}] = X_{t-1} + KAPPA * TAU * (X_BAR - X_{t-1}).

    The division is by SIGMA_SIGMA (a standard deviation), not its square.

    Args:
        prev_particle: Previous hidden state X_{t-1}
        particle: Current hidden state X_t
        noise: Noise draw(s) for the observation

    Returns:
        Generated observations
    """
    # NOTE(review): the conditional mean always uses the minute-level TAU, even for
    # overnight/weekend transitions; confirm this is intended.
    expected_state = prev_particle + const.KAPPA * const.TAU * (const.X_BAR - prev_particle)
    innovation = particle - expected_state

    conditional_mean = -0.5 * const.TAU * jnp.exp(2 * particle) + const.RHO * (jnp.exp(particle) / const.SIGMA_SIGMA) * innovation
    conditional_var = const.TAU * jnp.exp(2 * particle) * (1 - const.RHO**2)

    return conditional_mean + noise * jnp.sqrt(conditional_var)

@dataclass(frozen=True)
class SimulatingForward:
    """
    A class for simulating forward particle paths and evaluating their likelihood.

    Starting states are drawn from a weighted particle set and evolved forward
    independently. One observation is generated per step of every path, and the
    realised variance of the target data is scored under a kernel density estimate of
    the simulated realised variances.

    Attributes:
        min_f_from_noise: Function to evolve state for minute-level periods
        overnight_f_from_noise: Function to evolve state for overnight periods
        observation_from_noise: Function (prev_state, state, noise) -> observation
        weekend_f_from_noise: Function to evolve state for weekend periods (defaults to overnight)
        other_f_from_noise: Function to evolve state for other periods (defaults to weekend)
        N_SAMPLES: Number of forward simulation samples (default: 2500)
    """
    min_f_from_noise: callable
    overnight_f_from_noise: callable
    observation_from_noise: callable
    weekend_f_from_noise: callable = None
    other_f_from_noise: callable = None
    N_SAMPLES: int = 2500

    def __post_init__(self):
        """
        Post-initialization setup for the forward simulator.

        Sets default functions for weekend and other periods if not provided.
        """
        # Workaround for frozen=True: use object.__setattr__
        if self.weekend_f_from_noise is None:
            object.__setattr__(self, 'weekend_f_from_noise', self.overnight_f_from_noise)
        if self.other_f_from_noise is None:
            object.__setattr__(self, 'other_f_from_noise', self.weekend_f_from_noise)

    def evaluation_metrics_from_paths(self, sampled_forward_particles, sampled_observation_paths,
                                    Y_array, tau_id_array, total_tau_time):
        """
        Score the observed data against the simulated paths.

        The time-normalised realised variance (sum of squared observations divided by
        total_tau_time) is computed for the target and for every simulated path. The
        target's value is then evaluated under a Gaussian KDE (Scott bandwidth) fitted to
        the simulated values.

        Args:
            sampled_forward_particles: Simulated particle states (unused; kept for interface)
            sampled_observation_paths: Simulated observations, shape (N_SAMPLES, T)
            Y_array: Target observation array
            tau_id_array: Time period identifiers (unused; kept for interface)
            total_tau_time: Total time duration in minutes

        Returns:
            The KDE log-density (``kde.logpdf``) of the target realised variance.
        """
        tau_normalized_constant = 1 / total_tau_time
        target_realized_vol = tau_normalized_constant * jnp.sum(Y_array**2)
        sampled_realized_vol = tau_normalized_constant * jnp.sum(sampled_observation_paths**2, axis=1)

        kde = jsp.stats.gaussian_kde(sampled_realized_vol, bw_method="scott")
        likelihood = kde.logpdf(target_realized_vol)

        return likelihood

    def combined_f_from_noise(self, tau_id, last_val, noise):
        """
        Select and apply the appropriate state evolution function based on tau_id.

        Args:
            tau_id: Time period identifier (0=minute, 1=overnight, 2=weekend, 3=other)
            last_val: Previous state value
            noise: Random noise for state evolution

        Returns:
            array: New state value
        """
        return jax.lax.switch(
            tau_id,
            [
                lambda: self.min_f_from_noise(last_val, noise),
                lambda: self.overnight_f_from_noise(last_val, noise),
                lambda: self.weekend_f_from_noise(last_val, noise),
                lambda: self.other_f_from_noise(last_val, noise),
            ]
        )

    @eqx.filter_jit
    def simulate_forward(self, key: jax.random.PRNGKey, initial_particles: jnp.ndarray,
                        initial_log_weights: jnp.ndarray, Y_array: jnp.ndarray,
                        tau_id_array: jnp.ndarray, total_tau_time: float):
        """
        Simulate forward particle paths and evaluate their likelihood.

        Starting points are drawn from the weighted initial particles and evolved over
        ``Y_array.shape[0]`` steps. Each step's observation is generated from the
        transition (X_{t-1}, X_t) *of the same path*, and the target is scored with
        ``evaluation_metrics_from_paths``.

        Args:
            key: JAX random key for the simulation
            initial_particles: Initial particle states
            initial_log_weights: Initial (normalised) log weights of particles
            Y_array: Target observation array
            tau_id_array: Integer time period identifiers, one per step (same length as Y_array)
            total_tau_time: Total time duration in minutes for normalization

        Returns:
            tuple: ((sampled_forward_particles, sampled_observation_paths), likelihood)
                - sampled_forward_particles: Simulated state paths, shape (N_SAMPLES, T, ...)
                - sampled_observation_paths: Simulated observation paths, same shape
                - likelihood: KDE log-density of the target realised variance
        """
        sample_key, path_key, observation_key = jax.random.split(key, 3)

        # Sample starting points from the initial particle distribution.
        starting_point_indices = jax.random.choice(
            sample_key, initial_particles.shape[0], (self.N_SAMPLES,),
            p=jnp.exp(initial_log_weights)
        )
        starting_points = initial_particles[starting_point_indices]

        def individual_jump_body(carry, time_slice):
            """Single step of state evolution."""
            tau_id, noise = time_slice
            next_val = self.combined_f_from_noise(tau_id, carry, noise)
            return next_val, next_val

        def scan_fn(initial_point, tau_id_array, jump_noise):
            """Evolve one starting point over all time steps."""
            _, hidden_state_evolution = jax.lax.scan(individual_jump_body, initial_point, (tau_id_array, jump_noise))
            return hidden_state_evolution

        # Generate forward particle paths: shape (N_SAMPLES, T, ...).
        jump_noises = jax.random.normal(path_key, (self.N_SAMPLES, Y_array.shape[0]))
        sampled_forward_particles = jax.vmap(scan_fn, in_axes=(0, None, 0))(
            starting_points, tau_id_array, jump_noises
        )

        def observation_jump_body(full_path, time_slice):
            """Observation for step idt from the transition full_path[idt-1] -> full_path[idt]."""
            idt, noise = time_slice
            new_y = self.observation_from_noise(full_path.at[idt-1].get(), full_path.at[idt].get(), noise)
            return full_path, new_y

        def observation_scan_fn(starting_point, path, obs_noise):
            """Generate one observation per time step of a single path."""
            # Prepend the state the path evolved from, so the first step's predecessor is
            # its true previous state.
            full_path = jnp.concatenate((jnp.expand_dims(starting_point, axis=0), path))
            time_idx = jnp.arange(1, full_path.shape[0], dtype=jnp.int32)
            _, observations = jax.lax.scan(observation_jump_body, full_path, (time_idx, obs_noise))
            return observations

        # vmap over samples so each scan runs along *time* within one path.
        observation_noise = jax.random.normal(observation_key, sampled_forward_particles.shape)
        sampled_observation_paths = jax.vmap(observation_scan_fn)(
            starting_points, sampled_forward_particles, observation_noise
        )

        result_metrics = self.evaluation_metrics_from_paths(
            sampled_forward_particles, sampled_observation_paths,
            Y_array, tau_id_array, total_tau_time
        )

        return (sampled_forward_particles, sampled_observation_paths), result_metrics


Processing functions. This is where the largest changes will take place, as we are no longer iterating through a dataset.

In [None]:
import jax
import jax.numpy as jnp

import polars as pl
from tqdm.notebook import tqdm

import pf_functions.constants as const


def pre_processing(data_slice, forecast_negative_increments: list, n_particles: int = 2500, particle_initialisation_key: jax.random.PRNGKey = jax.random.key(0)):
    """
    Preprocess financial time series data for particle filtering analysis.

    Steps:
    1. Drop unused columns, map ``raw_tau`` to an integer ``tau_id``, remove rows whose
       ``log_returns`` is NaN, drop the first remaining row, and collect the result.
    2. For each entry of ``forecast_negative_increments``, add a boolean ``{i}_f_flag``
       column marking rows whose ``dt`` equals some ``front_expiry - hours``. These
       columns are OR-ed together into ``any_f_flag``.
    3. Draw initial particles and uniform log weights.

    Parameters:
    -----------
    data_slice : pl.LazyFrame or pl.DataFrame
        Input data with at least: dt, raw_tau, log_returns, front_expiry, back_implied,
        front_tte, back_tte, back_expiry
    forecast_negative_increments : list
        Hours before expiry at which to create forecast flags.
        Entry i creates the column {i}_f_flag
    n_particles : int, default=2500
        Number of particles to initialize
    particle_initialisation_key : jax.random.PRNGKey, default=jax.random.key(0)
        Random key for reproducible particle initialization

    Returns:
    --------
    tuple : (processed_df, (initial_particles, initial_log_weights))
        - processed_df: collected pl.DataFrame with preprocessing applied
        - initial_particles: jax.Array of shape (n_particles,)
        - initial_log_weights: jax.Array of shape (n_particles,), all -log(n_particles)

    Notes:
    ------
    - tau_id: null if raw_tau is null, 0 if raw_tau < 0.1 (hours), 1 if == 17.5,
      2 if == 65.5, otherwise 3
    - tau_id is assigned before NaN log_returns rows are removed, so gaps left by the
      removal keep the tau_id of the surviving rows
    - Particles are drawn from Normal(X_BAR, SIGMA_SIGMA)
    """
    lazy_frame = data_slice if isinstance(data_slice, pl.LazyFrame) else data_slice.lazy()

    tau_id_expr = (
        pl.when(pl.col("raw_tau").is_null()).then(None)
        .when(pl.col("raw_tau") < 0.1).then(0)  # less than ~6 mins (0.1 hr)
        .when(pl.col("raw_tau") == 17.5).then(1)
        .when(pl.col("raw_tau") == 65.5).then(2)
        .otherwise(3)
        .alias("tau_id")
    )

    frame = (
        lazy_frame
        .drop(['back_implied', 'front_tte', 'back_tte', 'back_expiry'])
        .with_columns([tau_id_expr])
        .filter(~pl.col("log_returns").is_nan())
        .slice(1)  # drop the first row
        .collect()
    )

    # Forecast flags: one boolean column per requested horizon.
    f_flag_columns = []
    for position, expiry_hour in enumerate(forecast_negative_increments):
        column_name = f"{position}_f_flag"
        forecast_times = frame.select(
            (pl.col("front_expiry") - pl.duration(hours=expiry_hour)).alias("forecast_time")
        ).unique().to_series()
        frame = frame.with_columns([
            pl.col("dt").is_in(pl.Series(forecast_times).implode()).alias(column_name)
        ])
        f_flag_columns.append(column_name)

    any_flag_expr = pl.fold(
        acc=pl.lit(False),
        function=lambda acc, x: acc | x,
        exprs=[pl.col(name) for name in f_flag_columns],
    ).alias("any_f_flag")
    frame = frame.with_columns([any_flag_expr])

    true_count = frame.filter(pl.col("any_f_flag") == True).height
    print(f"Number of Forecast Points values in any_f_flag: {true_count}")
    print(f"Total height of data frame: {frame.height}")

    # Initial particle cloud and uniform log weights.
    initial_particles = const.X_BAR + const.SIGMA_SIGMA * jax.random.normal(particle_initialisation_key, (n_particles, ))
    initial_log_weights = -jnp.log(n_particles) * jnp.ones_like(initial_particles)

    return frame, (initial_particles, initial_log_weights)


def processing(key, particle_filter, processed_data_base, particles_weights_tuple, break_after: int = -1, verbose: bool = True):
    """
    Process a time series dataset using a particle filter, segmenting the data at forecast points.

    Segments run between consecutive rows where 'any_f_flag' is True. Each segment is
    filtered with ``particle_filter.simulate``, and particle states and weights carry over
    from one segment to the next. Diagnostics are collected for every segment.

    Parameters:
    -----------
    key : jax.random.PRNGKey
        Random number generator key for JAX operations
    particle_filter : ParticleFilter
        Particle filter object with a 'simulate' method and a 'Y_LOOK_FORWARD' attribute
    processed_data_base : polars.DataFrame
        Preprocessed dataset containing columns: 'any_f_flag', 'tau_id', 'log_returns'
    particles_weights_tuple : tuple
        Initial particles and log weights as (particles, log_weights)
    break_after : int, optional (default=-1)
        Number of segments to process before stopping. Any value < 1 (including the
        default -1) processes all segments.
    verbose : bool, optional (default=True)
        Whether to display the progress bar and detailed output for each segment

    Returns:
    --------
    tuple : (diagnostic_from_segment_dict, particle_and_weights_at_flag_idx)
        - diagnostic_from_segment_dict: maps (last_idx + 1, flag_idx + 1) row bounds to
          the diagnostics dict returned by ``simulate``
        - particle_and_weights_at_flag_idx: maps each processed flag row index to the
          (particles, log_weights) at the end of its segment

    Raises:
    -------
    NotImplementedError
        If a segment's Y_LOOK_FORWARD look-ahead rows would run past the end of the data.

    Notes:
    ------
    - tau ids for a segment come from rows last_idx+1 .. flag_idx. Log returns start one
      row earlier, at last_idx, and extend Y_LOOK_FORWARD rows past the segment.
      NOTE(review): confirm this one-row offset between Y and tau_id is intended.
    """
    diagnostic_from_segment_dict = {}
    particle_and_weights_at_flag_idx = {}

    # Row indices where any_f_flag is True (False/null rows are skipped).
    flag_indices = [i for i, flag in enumerate(processed_data_base['any_f_flag']) if flag]

    last_idx = 0  # Row index of the previous forecast point (row 0 before the first segment).
    total_height = processed_data_base.height

    last_particles, last_weights = particles_weights_tuple

    # The loop stops early only when break_after >= 1 (see the check at the loop end).
    n_segments = len(flag_indices) if break_after < 1 else min(len(flag_indices), break_after)

    # Context manager closes the progress bar even when the loop exits via `break`.
    with tqdm(enumerate(flag_indices), total=n_segments, desc="Processing segments",
              unit="segment", disable=not verbose) as pbar:
        for go, flag_idx in pbar:

            # 1. Prepare for the simulation.
            segment_to_process = processed_data_base.slice(last_idx + 1, flag_idx - last_idx)
            tau_id_segment = segment_to_process['tau_id'].to_jax()

            # Log returns include Y_LOOK_FORWARD extra rows after the segment.
            log_return_steps = flag_idx - last_idx + particle_filter.Y_LOOK_FORWARD

            if last_idx + log_return_steps > total_height:
                raise NotImplementedError(
                    f"Segment ending at row {flag_idx} needs {particle_filter.Y_LOOK_FORWARD} "
                    f"look-ahead row(s) beyond the end of the data (height {total_height})."
                )

            log_return_segment = processed_data_base.slice(last_idx, log_return_steps)['log_returns'].to_jax()

            # 2. Run the simulation.
            key, pf_step_key = jax.random.split(key)
            out_particles, out_weights, diagnostics = particle_filter.simulate(
                pf_step_key,
                last_particles,
                last_weights,
                log_return_segment,
                tau_id_segment
            )

            if verbose:
                print(f"Segment {go}: Processing from index {last_idx} to {flag_idx} ({(flag_idx/total_height)*100:.1f}% of dataset)")
                print(f"ESS: {jnp.mean(diagnostics['ess']):.3f}, Resample flag: {jnp.mean(diagnostics['resample_flag']):.3f}")
                print(f"Marginal likelihood: {jnp.sum(jnp.log(diagnostics['marginal_likelihood'])):.6f}")
                print("-" * 50)

            # 3. Store the outputs and carry the state into the next segment.
            diagnostic_from_segment_dict[(last_idx + 1, flag_idx + 1)] = diagnostics
            particle_and_weights_at_flag_idx[flag_idx] = (out_particles, out_weights)

            last_particles = out_particles
            last_weights = out_weights
            last_idx = flag_idx

            if break_after - 1 == go:
                break

    return diagnostic_from_segment_dict, particle_and_weights_at_flag_idx


Existing Wrapper Class

In [None]:
from dataclasses import dataclass

import jax
import polars as pl

from .particle_filter import ParticleFilter
from .simulate_forward import SimulatingForward
from .processing_functions import pre_processing, processing
from .simulate_forward_processing import simulate_forward_processing

@dataclass(frozen=True)
class RealWorldWrapper:
    """
    A wrapper class for running particle filter simulations on real-world data.

    On construction the first ``n_rows_to_process`` rows of ``path_to_data`` are
    loaded and passed through ``pre_processing``. ``run_from_particle_filter`` then
    runs the particle filter over the pre-processed data (``processing``) followed by
    forward simulations (``simulate_forward_processing``).

    Attributes:
        forecast_negative_increments: List of forecast horizons in hours (negative values)
        simulate_from_func: Default forward-simulation object, used when
            ``run_from_particle_filter`` is not given one
        path_to_data: Path to the parquet data file
        n_rows_to_process: Number of rows to load from the data file

    Attributes set in ``__post_init__`` (not dataclass fields):
        loaded_data: The loaded rows, converted with ``.lazy()`` to a polars LazyFrame
        pre_processed_data_base: First value returned by ``pre_processing``
        initial_particles_weights: Second value returned by ``pre_processing``
    """
    forecast_negative_increments: list
    simulate_from_func: SimulatingForward = None
    # NOTE(review): machine-specific default; pass path_to_data explicitly elsewhere.
    path_to_data: str = r"C:\Users\chris\OneDrive\Belgeler\ParticleFilter\ParticleFilter\real_data\preprocessed_synth_price_data.parquet"
    n_rows_to_process: int = 100000

    def __post_init__(self):
        """Initialize the wrapper by loading and preprocessing data."""
        # The dataclass is frozen, so derived attributes must bypass __setattr__.
        object.__setattr__(self, 'loaded_data', pl.read_parquet(
            self.path_to_data, 
            n_rows=self.n_rows_to_process
        ).lazy())

        # Run the pre-processing
        pre_processed_data_base, initial_particles_weights = pre_processing(
            self.loaded_data, 
            self.forecast_negative_increments
        )
        object.__setattr__(self, 'pre_processed_data_base', pre_processed_data_base)
        object.__setattr__(self, 'initial_particles_weights', initial_particles_weights)

    def final_step_processing(self, diagnostic_from_segment_dict, raw_fit_metrics):
        """
        Post-process simulation results. Override this method for custom processing.

        The base implementation returns both inputs unchanged.

        Args:
            diagnostic_from_segment_dict: Dictionary containing diagnostic results
            raw_fit_metrics: Fit metrics from the forward simulations

        Returns:
            Tuple of (processed_diagnostics, processed_fit_metrics)
        """
        return diagnostic_from_segment_dict, raw_fit_metrics

    def run_from_particle_filter(self,
                               key: jax.random.PRNGKey,
                               particle_filter: ParticleFilter,
                               simulate_from_func: SimulatingForward = None,
                               include_raw_paths: bool = False,
                               break_after: int = -1,
                               verbose: bool = True):
        """
        Run the complete particle filter pipeline.

        Args:
            key: JAX random key; split into one key for filtering and one for
                forward simulation
            particle_filter: ParticleFilter instance to run over the pre-processed data
            simulate_from_func: Forward-simulation object; falls back to
                ``self.simulate_from_func`` when None
            include_raw_paths: If True, also return the raw simulated paths
            break_after: Maximum number of segments to process (-1 for all)
            verbose: Whether to print progress information

        Returns:
            ``(final_diagnostics, final_fit_metrics)``, or
            ``(final_diagnostics, final_fit_metrics, raw_path_dict)`` when
            ``include_raw_paths`` is True.

        Raises:
            ValueError: If no forward-simulation object is available, neither as an
                argument nor on the instance. This is checked before the potentially
                long particle-filter pass rather than after it.
        """
        # Resolve the forward simulator up front so a missing one fails fast.
        if simulate_from_func is None:
            simulate_from_func = self.simulate_from_func
        if simulate_from_func is None:
            raise ValueError(
                "No forward simulator available: pass simulate_from_func or set it on the wrapper."
            )

        # Split key for separate processing and forward simulation
        processing_key, forward_key = jax.random.split(key)

        # Run particle filtering
        diagnostic_from_segment_dict, particle_and_weights_at_flag_idx = processing(
            processing_key, 
            particle_filter, 
            self.pre_processed_data_base, 
            self.initial_particles_weights, 
            break_after, 
            verbose
        )

        # Run forward simulations from the particles stored at each flag index
        raw_path_dict, raw_fit_metrics = simulate_forward_processing(
            forward_key, 
            simulate_from_func, 
            self.pre_processed_data_base, 
            particle_and_weights_at_flag_idx, 
            self.forecast_negative_increments, 
            break_after, 
            verbose
        )

        # Apply final processing step (hook for subclasses)
        final_diagnostic_from_segment_dict, final_raw_fit_metrics = self.final_step_processing(
            diagnostic_from_segment_dict, 
            raw_fit_metrics
        )
        if include_raw_paths:
            return final_diagnostic_from_segment_dict, final_raw_fit_metrics, raw_path_dict
        return final_diagnostic_from_segment_dict, final_raw_fit_metrics


Previously started wrapper class.

In [None]:
import pf_functions.config as config
# NOTE(review): set_constants() runs before `pf_functions` is imported as PF,
# presumably because the package reads these constants at import time — confirm
# before reordering these lines.
config.set_constants()

import pf_functions as PF


def model_weight_function(model, particles, inputs):
    """Placeholder for the proposal log-density of ``particles`` given ``inputs``.

    Intended call shape: ``model`` is the proposal model, ``particles`` the proposed
    particles, and ``inputs`` the conditioning inputs built from the previous
    particles and the current observation.

    Raises:
        NotImplementedError: Always. Supply a real implementation instead of this stub.
    """
    # A bare `pass` would silently return None and fail later in a confusing way.
    raise NotImplementedError(
        "model_weight_function is a placeholder; provide a real implementation."
    )

def model_sample_from_inputs(model, key, inputs):
    """Placeholder for drawing new particles from the proposal model.

    Intended call shape: ``model`` is the proposal model, ``key`` a JAX random key,
    and ``inputs`` the conditioning inputs built from the previous particles and the
    current observation.

    Raises:
        NotImplementedError: Always. Supply a real implementation instead of this stub.
    """
    # A bare `pass` would silently return None and fail later in a confusing way.
    raise NotImplementedError(
        "model_sample_from_inputs is a placeholder; provide a real implementation."
    )


class NeuralNetworkProcess:
    def __init__(self,
                 initial_model: eqx.Module,
                 data_generation_function: callable,
                 training_data_generation_function: callable,
                 model_weight_function: callable,
                 model_sample_from_inputs: callable,
                 state_transition_weight: callable,
                 state_observation_weight: callable):
        """Store the model, the data generators and the weight functions.

        Both data-generation functions should already be partially applied, so that
        this class never has to forward extra configuration to them.

        ``state_transition_weight`` and ``state_observation_weight`` are written for a
        single particle. Here each is wrapped with ``jax.vmap`` over its first two
        arguments (current and previous particles), with the last two arguments
        (observation array and time index) shared across particles. The result is then
        compiled with ``jax.jit``.
        """
        def _per_particle(single_fn):
            # Map over (particles, prev_particles); broadcast (Y_array, idt).
            return jax.jit(jax.vmap(single_fn, in_axes=(0, 0, None, None)))

        # `model` starts as the initial model and is replaced by train().
        self.initial_model = initial_model
        self.model = initial_model

        self.data_generation_function = data_generation_function
        self.training_data_generation_function = training_data_generation_function

        self.model_weight_function = model_weight_function
        self.model_sample_from_inputs = model_sample_from_inputs

        self.state_transition_weight = _per_particle(state_transition_weight)
        self.state_observation_weight = _per_particle(state_observation_weight)
 

    def get_training_data(self,
                         training_params: tuple[int, int] = (1000, 5000),
                         testing_params: tuple[int, int] = (100, 2500),
                         training_key: int = 42,
                         testing_key: int = 52,
                         verbose: bool = True):
        """Draw a training set and a held-out test set from the training-data generator.

        Args:
            training_params: ``(n_batches, batch_size)`` for the training set.
            testing_params: ``(n_batches, batch_size)`` for the test set.
            training_key: Integer seed for the training-set PRNG key.
            testing_key: Integer seed for the test-set PRNG key.
            verbose: If true, print the requested sizes and the resulting shapes.

        Returns:
            ``(inputs, targets, test_inputs, test_targets)``. The generator's input
            batches are joined with ``jnp.vstack`` and its target batches with
            ``jnp.hstack``.
        """
        if verbose:
            print(f"Generating training data: {training_params[0]} batches of size {training_params[1]}")
            print(f"Generating testing data: {testing_params[0]} batches of size {testing_params[1]}")

        train_rng = jax.random.key(training_key)
        test_rng = jax.random.key(testing_key)

        def _draw(rng, params):
            # One call to the generator, then flatten its list of batches.
            batch_inputs, batch_targets = self.training_data_generation_function(
                rng,
                params[0],
                params[1],
            )
            return jnp.vstack(batch_inputs), jnp.hstack(batch_targets)

        inputs, targets = _draw(train_rng, training_params)
        test_inputs, test_targets = _draw(test_rng, testing_params)

        if verbose:
            print(f"Training data shape: inputs {inputs.shape}, targets {targets.shape}")
            print(f"Testing data shape: inputs {test_inputs.shape}, targets {test_targets.shape}")

        return inputs, targets, test_inputs, test_targets


    def train(self, 
              training_params: tuple[int, int] = (1000, 5000), 
              testing_params: tuple[int, int] = (100, 2500),
              steps: int = 5000,
              learning_rate: float = 1e-2,
              batch_size: int = 5000,
              train_key: int = 20,
              eval_frequency: int = 100,
              verbose: bool = True):
        """Fit the proposal model by minimising its negative log-likelihood with Adam.

        Data comes from ``get_training_data``. Each step uses the next mini-batch of
        a shuffled permutation of the training set. The permutation is redrawn once it
        can no longer supply a full batch.

        Args:
            training_params: (n_batches, batch_size) for training data
            testing_params: (n_batches, batch_size) for testing data
            steps: Number of training steps
            learning_rate: Learning rate for the Adam optimizer
            batch_size: Mini-batch size for each training step
            train_key: Integer seed for mini-batch shuffling
            eval_frequency: Evaluate the test loss every ``eval_frequency`` steps;
                a non-positive value disables test evaluation
            verbose: Whether to print progress information and show the progress bar

        Side effects:
            Sets ``self.model``, ``self.vectorised_model`` (``jax.vmap`` of the trained
            model), ``self.training_losses`` and ``self.test_losses``.
        """
        if verbose:
            print(f"Starting training with {steps} steps, lr={learning_rate}, batch_size={batch_size}")

        inputs, targets, test_inputs, test_targets = self.get_training_data(
            training_params, testing_params, verbose=verbose
        )

        model = self.model
        train_key = jax.random.key(train_key)

        # NOTE(review): `optax` and `tqdm` are not imported in this cell; they are
        # assumed to be imported elsewhere in the notebook.
        optimizer = optax.adam(learning_rate=learning_rate)
        opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

        def neg_log_likelihood(model, x, z):
            # Argument order matches the module-level stub and NN_weight_fn:
            # model_weight_function(model, particles, inputs), i.e. the targets
            # (particles) first and the conditioning inputs second.
            log_likelihood = self.model_weight_function(model, z, x)
            return -jnp.mean(log_likelihood)

        # The training step needs value and gradient. Evaluation needs the value only;
        # a value_and_grad result is a (value, grads) pair, not a scalar.
        loss_and_grad = eqx.filter_value_and_grad(neg_log_likelihood)

        @eqx.filter_jit
        def batched_train_step(model, x, y, opt_state, optimizer):
            neg_ll, grads = loss_and_grad(model, x, y)
            updates, opt_state = optimizer.update(grads, opt_state)
            model = eqx.apply_updates(model, updates)
            return model, opt_state, neg_ll

        @eqx.filter_jit
        def evaluate(model, x, y):
            return neg_log_likelihood(model, x, y)

        losses = []
        test_losses = []
        loop = tqdm(range(steps), disable=not verbose)

        n_train = inputs.shape[0]
        shuffle_key, train_key = jax.random.split(train_key)
        shuffled_ix = jax.random.permutation(shuffle_key, n_train)

        batch_ix = 0
        for step in loop:
            # Redraw the permutation once it cannot supply another full batch.
            if (batch_ix + 1) * batch_size > n_train:
                shuffle_key, train_key = jax.random.split(train_key)
                shuffled_ix = jax.random.permutation(shuffle_key, n_train)
                batch_ix = 0

            batch_idx = shuffled_ix[batch_ix * batch_size:(batch_ix + 1) * batch_size]
            batch_ix += 1

            model, opt_state, neg_ll = batched_train_step(
                model, inputs[batch_idx], targets[batch_idx], opt_state, optimizer
            )
            losses.append(neg_ll)

            # Evaluate on the test set periodically (skipped if eval_frequency <= 0).
            if eval_frequency > 0 and step % eval_frequency == 0:
                test_loss = evaluate(model, test_inputs, test_targets)
                test_losses.append(test_loss)
                if verbose:
                    loop.set_postfix({
                        'train_loss': f'{neg_ll:.4f}',
                        'test_loss': f'{test_loss:.4f}',
                        'step': step
                    })
            elif verbose:
                loop.set_postfix({'train_loss': f'{neg_ll:.4f}', 'step': step})

        if verbose:
            print("\nTraining completed!")
            if losses:
                print(f"Final training loss: {losses[-1]:.4f}")
            if test_losses:
                print(f"Final test loss: {test_losses[-1]:.4f}")
            print(f"Best test loss: {min(test_losses):.4f}" if test_losses else "No test evaluation")

        self.model = model
        self.vectorised_model = jax.vmap(model)
        self.training_losses = losses
        self.test_losses = test_losses

    def build_model_inputs(self, prev_particles, Y_array, idt):
        """Pair every previous particle with the observation at index ``idt``.

        The result stacks two columns side by side. The first is ``prev_particles``
        reshaped to a column. The second repeats ``Y_array`` at ``idt`` once per row of
        ``prev_particles``. This assumes one scalar per particle; confirm for
        multi-dimensional states.
        """
        n_rows = prev_particles.shape[0]
        current_obs = Y_array.at[idt].get()
        obs_col = jnp.full((n_rows, 1), current_obs)
        particle_col = prev_particles.reshape(-1, 1)
        return jnp.hstack((particle_col, obs_col))


    def create_NN_weight_and_sample_from_model_and_proposal_for_skew(self, model):
        """
        Build the sampling and weighting functions of a neural-network proposal.

        Both closures capture ``model``. They pass it as the first argument to the
        user-supplied ``self.model_sample_from_inputs`` and
        ``self.model_weight_function``. The conditioning inputs come from
        ``self.build_model_inputs``.

        Args:
            model: Model object forwarded to ``model_sample_from_inputs`` and
                ``model_weight_function``.

        Returns:
            tuple: ``(NN_sample_fn, NN_weight_fn)``. Note that the sampling function
            comes first.
        """

        def NN_weight_fn(particles, prev_particles, Y_array, idt):
            """
            Compute log importance weights for particles drawn from the NN proposal.

            Args:
                particles: Newly proposed particles.
                prev_particles: Particles from the previous step.
                Y_array: Observation array, indexed with ``idt``.
                idt: Current time index into ``Y_array``.

            Returns:
                Log weights ``log g + log f - log q``, one per particle.
            """
            # Observation term g(y | x).
            # NOTE(author): may need to become the leverage model (e.g. reuse
            # leverage_bootstrap_weight) — still open.
            g_y_i_from_x_i = self.state_observation_weight(particles, prev_particles, Y_array, idt)

            # Transition term f(x | x_prev).
            f_x_i_from_x_prev_i = self.state_transition_weight(particles, prev_particles, Y_array, idt)

            # Proposal term q(x | x_prev, y), evaluated through the NN model.
            inputs = self.build_model_inputs(prev_particles, Y_array, idt)
            q_x_i_from_x_prev_i_y_i = self.model_weight_function(model, particles, inputs)

            # Return log weight: log(g) + log(f) - log(q)
            return g_y_i_from_x_i + f_x_i_from_x_prev_i - q_x_i_from_x_prev_i_y_i

        def NN_sample_fn(subkey, prev_particles, Y_array, idt):
            """
            Draw new particles from the neural-network proposal.

            Args:
                subkey: JAX random key.
                prev_particles: Particles from the previous step.
                Y_array: Observation array, indexed with ``idt``.
                idt: Current time index into ``Y_array``.

            Returns:
                New particles produced by ``self.model_sample_from_inputs``.
            """
            inputs = self.build_model_inputs(prev_particles, Y_array, idt)
            new_particles = self.model_sample_from_inputs(model, subkey, inputs)

            return new_particles

        return NN_sample_fn, NN_weight_fn
            

    def get_particle_filter(self, N_particles: int,
                            input_initial_sample_fn: callable = None,
                            input_intial_weight_fn: callable = None,
                            input_resample_fn: callable = None,
                            vmap: bool = True):
        """
        Create a particle filter that uses the current neural-network model as proposal.

        Parameters:
        -----------
        N_particles : int
            Number of particles to use in the filter.
        input_initial_sample_fn : callable, optional
            Initial sampling function; defaults to
            ``PF.particle_filters.initial_sample_fn``.
        input_intial_weight_fn : callable, optional
            Initial weighting function; defaults to
            ``PF.particle_filters.intial_weight_fn``. The misspelt parameter name is
            kept for backward compatibility.
        input_resample_fn : callable, optional
            Resampling function; defaults to
            ``PF.particle_filters.multinomial_resample``.
        vmap : bool, optional
            Forwarded unchanged to ``PF.particle_filters.create_particle_filter``
            (presumably toggles vectorisation), by default True.

        Returns:
        --------
        The object returned by ``PF.particle_filters.create_particle_filter``, built
        from the neural-network sample and weight functions.
        """
        # Fall back to the package defaults only for functions the caller omitted;
        # caller-supplied functions are used as given.
        chosen_initial_sample_fn = (
            PF.particle_filters.initial_sample_fn
            if input_initial_sample_fn is None
            else input_initial_sample_fn
        )
        chosen_intial_weight_fn = (
            PF.particle_filters.intial_weight_fn
            if input_intial_weight_fn is None
            else input_intial_weight_fn
        )
        chosen_resample_fn = (
            PF.particle_filters.multinomial_resample
            if input_resample_fn is None
            else input_resample_fn
        )

        # NOTE(review): the filter receives jax.vmap(self.model), whereas train()
        # passes the un-vmapped model to model_weight_function with similarly batched
        # inputs. Confirm which form model_weight_function expects.
        vectorised_model = jax.vmap(self.model)

        # NOTE(review): previously called `create_NN_weight_and_sample_from_model_and_proposal`,
        # which is not defined alongside this method; the defined factory is the
        # `_for_skew` variant.
        sample_fn, weight_fn = self.create_NN_weight_and_sample_from_model_and_proposal_for_skew(
            vectorised_model,
        )

        return PF.particle_filters.create_particle_filter(
            sample_fn,
            weight_fn,
            chosen_initial_sample_fn,
            chosen_intial_weight_fn,
            chosen_resample_fn,
            N_particles,
            vmap,
        )

    def evaluate_model(self, N_particles: int):
        """Evaluate the trained model's particle filter. Not implemented yet.

        Planned steps, from the original notes:

        1. Generate data.
        2. Load the model particle filter.
        3. Load the bootstrap particle filter.
        4. Use simple plots to ... (unfinished in the original notes; presumably
           compare the two filters).

        Args:
            N_particles: Number of particles for the filters.

        Raises:
            NotImplementedError: Always, until the steps above are written.
        """
        # A body of only comments is a SyntaxError that broke the whole cell.
        raise NotImplementedError("evaluate_model is not implemented yet.")

    def evaluate_model_and_show(self, **kwargs ):
        eval = self.evaluate_model(**kwargs)
