# Advanced Usage  

As the title suggests, this section describes a number of use cases that are rarely required but nonetheless supported.

## Real-Fields  

Generally, in ultrafast pulse retrieval it is sufficient (and convenient) to describe the pulses as complex fields. However, this comes with a drawback: complex-valued fields are not Hermitian, i.e. their spectrum does not satisfy E(-ω) = E*(ω). The pulses they describe therefore possess no negative frequencies and cannot describe nonlinear processes involving difference-frequency generation (DFG). This includes explicit DFG traces as well as measurement techniques like [TREX](https://github.com/matillda123/Pulse-Retrieval-with-JAX/blob/main/examples/simulate_and_retrieve_TREX.py), which simultaneously measure multiple nonlinear signals.  
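To make the difference concrete, the following toy sketch compares a complex field with its real part using plain `jax.numpy`, independently of `pulsedjax`. The Gaussian test pulse and the helpers `is_hermitian` and `low_frequency_fraction` are made up for illustration. Only the real field has a Hermitian spectrum, and only the square of the real field contains a difference-frequency term near f = 0 (the pulse mixing with itself) in addition to the second harmonic.

```python
import jax.numpy as jnp

n = 2048
dt = 100.0 / n                      # time step (arbitrary units)
t = -50.0 + dt * jnp.arange(n)
f0 = 0.2                            # carrier frequency (arbitrary units)
envelope = jnp.exp(-((t / 10.0) ** 2))
field_complex = envelope * jnp.exp(2j * jnp.pi * f0 * t)  # complex field: positive frequencies only
field_real = field_complex.real                           # real field (E + E*) / 2

freq = jnp.fft.fftfreq(n, d=dt)
neg = (-jnp.arange(n)) % n          # index of -f for every frequency f on the FFT grid


def is_hermitian(x):
    """True if the spectrum of x satisfies X(-f) = conj(X(f)) up to float32 round-off."""
    spec = jnp.fft.fft(x)
    tol = 1e-4 * float(jnp.max(jnp.abs(spec)))
    return bool(jnp.allclose(spec[neg], jnp.conj(spec), rtol=0.0, atol=tol))


def low_frequency_fraction(x, f_cut=0.1):
    """Fraction of the spectral power at |f| < f_cut, where difference-frequency terms appear."""
    power = jnp.abs(jnp.fft.fft(x)) ** 2
    return float(jnp.sum(jnp.where(jnp.abs(freq) < f_cut, power, 0.0)) / jnp.sum(power))


print(is_hermitian(field_real), is_hermitian(field_complex))  # True False
# second-order signal: the real field yields a DFG term near f = 0 plus SHG near +-2 f0,
# the complex field only yields the SHG term near +2 f0
print(low_frequency_fraction(field_real**2), low_frequency_fraction(field_complex**2))  # approx. 0.67 and approx. 0.0
```

For a single carrier, about two thirds of the power of the squared real field sits near f = 0, which is exactly the contribution a complex-field model drops.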
To solve this, the `pulsedjax.real_fields` module explicitly uses real-valued fields to calculate nonlinear signals. This comes with a series of drawbacks. Most Classical Algorithms are not set up to be used with real fields, since this would require re-deriving their analytic gradients (an algorithm like C-PCGPA might work, but this has not been tested). Thus only PIE and the General Algorithms are available with `real_fields`. Additionally, the presence of negative frequencies requires a large frequency axis, which increases the computational demand. To avoid convergence issues, the pulses are only defined on a user-specified frequency range. However, the continuous interpolation between the different frequency axes requires additional computational resources.  
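The combination of a band-limited pulse definition and a full real-valued time axis can be sketched as follows, assuming the complex spectrum is only parametrized on a band `[fmin, fmax]` and linearly interpolated onto the full FFT axis. The helper `real_field_from_band` is hypothetical and not the `pulsedjax` implementation. `jnp.fft.irfft` supplies the negative frequencies through E(-f) = E*(f), so the resulting time-domain field is real.

```python
import jax.numpy as jnp


def real_field_from_band(f_band, spectrum_band, n_time, dt):
    """Real time-domain field from a complex spectrum that is only known on [f_band[0], f_band[-1]].

    Hypothetical helper: the spectrum is linearly interpolated onto the non-negative
    frequencies of the full FFT axis and set to zero outside the band.
    """
    freq = jnp.fft.rfftfreq(n_time, d=dt)  # 0 ... Nyquist frequency of the full axis
    re = jnp.interp(freq, f_band, spectrum_band.real, left=0.0, right=0.0)
    im = jnp.interp(freq, f_band, spectrum_band.imag, left=0.0, right=0.0)
    # irfft assumes E(-f) = conj(E(f)), i.e. it fills in the negative frequencies,
    # so the returned field is real-valued
    return jnp.fft.irfft(re + 1j * im, n=n_time)


# hypothetical chirped spectrum on the band (0.1, 0.25), arbitrary units
f_band = jnp.linspace(0.1, 0.25, 64)
spectrum_band = jnp.exp(-(((f_band - 0.175) / 0.03) ** 2)) * jnp.exp(1j * 50.0 * (f_band - 0.175) ** 2)
field_t = real_field_from_band(f_band, spectrum_band, n_time=1024, dt=0.5)
```

Only the 64 band values would be optimized, while `n_time` has to be large enough that the full axis covers every frequency generated by the nonlinear process; the interpolation step is the extra cost mentioned above.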

Below is a usage example of an SHG-FROG retrieval with real fields and the AutoDiff solver. The only difference to the complex-field version is the additional user-specified input `f_range_fields=(fmin, fmax)`.


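```python
from pulsedjax.real_fields import frog
import optax

# delay, frequency and trace (the measured SHG-FROG data) are assumed to be loaded already;
# f_range_fields=(fmin, fmax) is the frequency range on which the pulse is defined
ad = frog.AutoDiff(delay, frequency, trace, "shg", f_range_fields=(0.1,0.25), solver=optax.adam(learning_rate=0.1))

# 5 individuals with continuous amplitude and phase parametrizations
population = ad.create_initial_population(5, "continuous", "continuous")
final_result = ad.run(population, 50)
```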
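The cell above assumes that `delay`, `frequency` and `trace` already hold measured data. For a quick test, a synthetic SHG-FROG trace I(τ, f) = |FFT_t[E(t) E(t - τ)]|² can be generated as sketched below. This is a generic complex-field simulation; the chirped Gaussian test pulse, the units (fs and PHz) and the layout of `trace` (delays along rows, frequencies along columns) are assumptions, so check what `pulsedjax` expects, e.g. whether the frequency axis should be cropped to the SHG band, before passing it in.

```python
import jax
import jax.numpy as jnp


def simulate_shg_frog(time, pulse_t, delay):
    """Synthetic SHG-FROG trace |FFT_t[E(t) E(t - tau)]|^2, one row per delay tau."""
    dt = time[1] - time[0]
    freq = jnp.fft.fftfreq(time.size, d=dt)
    pulse_f = jnp.fft.fft(pulse_t)

    def one_delay(tau):
        # Fourier shift theorem: multiplying by exp(-2 pi i f tau) yields E(t - tau)
        gate_t = jnp.fft.ifft(pulse_f * jnp.exp(-2j * jnp.pi * freq * tau))
        signal_f = jnp.fft.fft(pulse_t * gate_t)
        return jnp.abs(jnp.fft.fftshift(signal_f)) ** 2

    trace = jax.vmap(one_delay)(delay)
    return jnp.fft.fftshift(freq), trace / jnp.max(trace)


# hypothetical test pulse: chirped Gaussian, time in fs, carrier at 0.175 PHz
time = jnp.linspace(-200.0, 200.0, 1024, endpoint=False)
pulse_t = jnp.exp(-((time / 30.0) ** 2) + 2e-4j * time**2) * jnp.exp(2j * jnp.pi * 0.175 * time)
delay = jnp.linspace(-150.0, 150.0, 121)
frequency, trace = simulate_shg_frog(time, pulse_t, delay)
```

Two quick sanity checks: an SHG-FROG trace is symmetric in τ, and its delay marginal (the sum over frequency) peaks at τ = 0.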

## Mixing Algorithms  

All algorithms are built around an `algorithm.step(descent_state, measurement_info, descent_info)` function, which always takes these three inputs. In principle, composite algorithms can therefore be created by chaining the step functions of different algorithms. However, `descent_state` and `descent_info` differ in structure and content between algorithms, and a composite algorithm has to account for these algorithm-dependent containers. For `descent_info` this is not an issue, since it is static. For `descent_state`, the user has to specify how the population, or individual members of it, are transferred between the different algorithms.  
Below is an example of an algorithm composed of `DifferentialEvolution` and `AutoDiff`. The goal is to refine the fittest individual with `AutoDiff` at each iteration of `DifferentialEvolution`.

```python
from pulsedjax.frog import DifferentialEvolution, AutoDiff
import optax

from pulsedjax.utilities import scan_helper
from equinox import tree_at
import jax.numpy as jnp
import jax
from jax.tree_util import Partial

# delay, frequency_trace and trace (the measured SHG-FROG data) are assumed to be loaded already
de = DifferentialEvolution(delay, frequency_trace, trace, "shg", strategy="best1_exp")
ad = AutoDiff(delay, frequency_trace, trace, "shg", solver=optax.adam(learning_rate=0.1))

# the types and numbers of basis functions must be the same for both algorithms,
# otherwise transferring individuals between the descent_states becomes difficult
population_de = de.create_initial_population(100, 
                                             amp_type="bsplines_5", phase_type="bsplines_5", 
                                             no_funcs_amp=15, no_funcs_phase=15)

population_ad = ad.create_initial_population(1, 
                                             amp_type="bsplines_5", phase_type="bsplines_5", 
                                             no_funcs_amp=15, no_funcs_phase=15)


descent_state_de, step_de = de.initialize_run(population_de)
descent_state_ad, step_ad = ad.initialize_run(population_ad)
# descent_state_ad is closed over (not carried through the scan), so every AutoDiff refinement
# starts from the same fresh state and there is no cross-talk between the runs of different individuals

def _step_composite(descent_state_de):

    # make one DifferentialEvolution step
    descent_state_de, error_de = step_de(descent_state_de, None)

    # extract the fittest individual
    idx = de.get_idx_best_individual(descent_state_de)
    fittest_individual = de.get_individual_from_idx(idx, descent_state_de.population)

    # insert this individual into the AutoDiff descent_state
    descent_state_ad_new = tree_at(lambda x: x.population, descent_state_ad, fittest_individual)

    # run an AutoDiff optimization of that individual for 50 iterations
    descent_state_ad_new, error_ad = jax.lax.scan(step_ad, descent_state_ad_new, length=50)

    # write the refined individual back into the DifferentialEvolution population
    population_de = jax.tree.map(lambda x,y: x.at[idx].set(y[0]), descent_state_de.population, descent_state_ad_new.population)
    descent_state_de = tree_at(lambda x: x.population, descent_state_de, population_de)

    # errors of the DE population plus the final AutoDiff error of the refined individual
    return descent_state_de, jnp.concatenate([error_de, error_ad[-1]])

# convert _step_composite into a lax.scan-compatible form
step_composite = Partial(scan_helper, actual_function=_step_composite, number_of_args=1, number_of_xs=0)

# run 100 composite iterations
descent_state_de, error_arr = jax.lax.scan(step_composite, descent_state_de, length=100)
final_result = de.post_process(descent_state_de, error_arr)
```
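The composite step above relies on populations being pytrees whose leaves share a leading population axis, so that a single individual can be pulled out and written back with `jax.tree.map`. The toy sketch below isolates this bookkeeping with a dict-based population; the population layout and the helpers `get_individual` and `set_individual` are hypothetical, and `pulsedjax` provides its own accessors such as `get_individual_from_idx`.

```python
import jax
import jax.numpy as jnp

# hypothetical population of 4 individuals: every leaf has a leading population axis
population = {
    "amp": jnp.arange(12.0).reshape(4, 3),
    "phase": -jnp.arange(12.0).reshape(4, 3),
}
trace_errors = jnp.array([0.5, 0.1, 0.7, 0.3])


def get_individual(population, idx):
    """Extract individual idx as a population of one (leading axis of length 1)."""
    return jax.tree.map(lambda leaf: leaf[idx][None], population)


def set_individual(population, idx, individual):
    """Return a new population in which slot idx is replaced by a population of one."""
    return jax.tree.map(lambda leaf, new: leaf.at[idx].set(new[0]), population, individual)


idx = jnp.argmin(trace_errors)                          # fittest = lowest trace error
best = get_individual(population, idx)                  # every leaf now has shape (1, 3)
refined = jax.tree.map(lambda leaf: 0.5 * leaf, best)   # stand-in for the AutoDiff refinement
population = set_individual(population, idx, refined)
```

Keeping the extracted individual as a population of one mirrors why the original cell writes back `y[0]`: the AutoDiff state holds a population of size 1.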

## Adding new Methods  

One idea behind `pulsedjax` is modularity, which should make it relatively easy to add new methods as well as new algorithms. All methods are implemented as `RetrievePulsesMETHOD(RetrievePulses)` in `pulsedjax.core.base_classes_methods.py`. Methods define how the data is preprocessed and added to `measurement_info`; in some cases, specific post-processing is needed as well. Most importantly, a method class needs to define how the nonlinear signal is calculated, via the methods `RetrievePulsesMETHOD.calculate_signal_t(self, individual, transform_arr, measurement_info)` and `RetrievePulsesMETHOD.generate_signal_t(self, descent_state, measurement_info, descent_info)`, where the latter essentially just applies the former to the whole population.  
Below is a bare-bones implementation of a `RetrievePulsesMETHOD` class.

```python
import jax
import jax.numpy as jnp
# RetrievePulses, get_sk_rn and MyNamespace are provided by pulsedjax;
# their import paths are not shown in this sketch


class RetrievePulsesMETHOD(RetrievePulses):

    def __init__(self, theta, frequency, measured_trace, nonlinear_method, *args, **kwargs):
        super().__init__(nonlinear_method, *args, **kwargs)

        self.theta, self.time, self.frequency, self.measured_trace, self.central_frequency = self.get_data(theta, frequency, measured_trace)

        self.dt = jnp.mean(jnp.diff(self.time))
        self.df = jnp.mean(jnp.diff(self.frequency))
        self.sk, self.rn = get_sk_rn(self.time, self.frequency)

        self.measurement_info = self.measurement_info.expand(theta = self.theta,
                                                             frequency = self.frequency,
                                                             time = self.time,
                                                             measured_trace = self.measured_trace,
                                                             cross_correlation = self.cross_correlation,
                                                             doubleblind = self.doubleblind,
                                                             dt = self.dt,
                                                             df = self.df,
                                                             sk = self.sk,
                                                             rn = self.rn,
                                                             x_arr = self.x_arr,
                                                             central_frequency = self.central_frequency)
        
        # for delay-based methods transform_arr is the same as tau_arr or delay,
        # for chirp scans it is the same as phase_matrix
        self.transform_arr = self.theta  # e.g. for a delay-based method, where theta holds the delays
        self.measurement_info = self.measurement_info.expand(transform_arr = self.transform_arr)



    def calculate_signal_t(self, individual, transform_arr, measurement_info):
        signal_t = ...  # placeholder: nonlinear signal of one individual in the time domain

        signal_f = self.fft(signal_t, measurement_info.sk, measurement_info.rn)
        # add further method-specific quantities as keyword arguments if needed
        signal_t = MyNamespace(signal_t=signal_t, signal_f=signal_f)
        return signal_t
    

    def generate_signal_t(self, descent_state, measurement_info, descent_info):
        transform_arr = measurement_info.transform_arr
        population = descent_state.population
        # apply calculate_signal_t to every individual of the population
        signal_t = jax.vmap(self.calculate_signal_t, in_axes=(0,None,None))(population, transform_arr, measurement_info)
        return signal_t



    def post_process_get_pulse_and_gate(self, descent_state, measurement_info, descent_info, idx=None):
        ...  # placeholder: extract pulse and gate (time and frequency domain) of individual idx
        return pulse_t, gate_t, pulse_f, gate_f
```
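To illustrate the split between `calculate_signal_t` (one individual) and `generate_signal_t` (the whole population via `jax.vmap`), here is a self-contained toy version for a polarization-gating (PG) FROG signal E(t)|E(t - τ)|². It uses plain arrays instead of `measurement_info` and dicts instead of `MyNamespace`; the simplified signatures and the helper `shift` are hypothetical and not the `pulsedjax` API.

```python
import jax
import jax.numpy as jnp


def shift(field_t, tau, frequency):
    """Delay a field by tau using the Fourier shift theorem: returns E(t - tau)."""
    return jnp.fft.ifft(jnp.fft.fft(field_t) * jnp.exp(-2j * jnp.pi * frequency * tau))


def calculate_signal_t(individual, transform_arr, frequency):
    """PG-FROG signal E(t)|E(t - tau)|^2 of ONE individual for all delays in transform_arr."""
    pulse_t = individual["pulse_t"]
    gate_t = jax.vmap(lambda tau: shift(pulse_t, tau, frequency))(transform_arr)  # (n_delay, n_time)
    signal_t = pulse_t[None, :] * jnp.abs(gate_t) ** 2
    signal_f = jnp.fft.fft(signal_t, axis=-1)
    return {"signal_t": signal_t, "signal_f": signal_f}


def generate_signal_t(population, transform_arr, frequency):
    """Apply calculate_signal_t to every individual (leading axis of the population)."""
    return jax.vmap(calculate_signal_t, in_axes=(0, None, None))(population, transform_arr, frequency)


# hypothetical population of three Gaussian pulses with different durations
time = jnp.linspace(-100.0, 100.0, 256, endpoint=False)
frequency = jnp.fft.fftfreq(time.size, d=time[1] - time[0])
delay = jnp.linspace(-50.0, 50.0, 21)
widths = jnp.array([10.0, 15.0, 20.0])
population = {"pulse_t": jnp.exp(-((time[None, :] / widths[:, None]) ** 2)).astype(jnp.complex64)}
signal = generate_signal_t(population, delay, frequency)  # leaves have shape (3, 21, 256)
```

Writing the physics once per individual and letting `vmap` add the population axis keeps the method class free of batching logic.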

A new final algorithm is then created through multiple inheritance, `Algorithm(AlgorithmBASE, RetrievePulsesMETHOD)`. Usually, a few method-specific methods also have to be defined in this final class. 
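Because the algorithm class comes first in the list of base classes, Python's method resolution order (MRO) looks up attributes in the algorithm class before the method class, and cooperative `super().__init__(*args, **kwargs)` calls hand the remaining constructor arguments on to the method class. The sketch below demonstrates this with hypothetical toy classes; it only mirrors the shape `Algorithm(AlgorithmBASE, RetrievePulsesMETHOD)` and makes no claim about how the actual `pulsedjax` base classes are wired.

```python
class RetrievePulsesToy:
    """Hypothetical stand-in for RetrievePulses: holds settings shared by every method."""

    def __init__(self, nonlinear_method, **kwargs):
        super().__init__(**kwargs)
        self.nonlinear_method = nonlinear_method


class RetrievePulsesToyMETHOD(RetrievePulsesToy):
    """Method layer: knows the data layout and how the nonlinear signal is formed."""

    def __init__(self, delay, nonlinear_method, **kwargs):
        super().__init__(nonlinear_method, **kwargs)
        self.transform_arr = delay

    def calculate_signal_t(self, individual, transform_arr):
        return [individual * tau for tau in transform_arr]  # toy "signal"


class ToyAlgorithmBASE:
    """Algorithm layer: knows how to update a population, not how the signal is formed.

    Only usable as a mixin: remaining arguments are forwarded along the MRO.
    """

    def __init__(self, *args, step_size=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.step_size = step_size
        self._name = "ToyAlgorithm"


class ToyAlgorithm(ToyAlgorithmBASE, RetrievePulsesToyMETHOD):
    """Final class: algorithm and method combined via multiple inheritance."""


algo = ToyAlgorithm([0.0, 1.0, 2.0], "shg", step_size=0.5)
print([cls.__name__ for cls in ToyAlgorithm.__mro__])
# ['ToyAlgorithm', 'ToyAlgorithmBASE', 'RetrievePulsesToyMETHOD', 'RetrievePulsesToy', 'object']
```

The algorithm layer consumes only its own keyword (`step_size`) and passes everything else through, which is why the final class can be assembled without either layer knowing about the other.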

## Adding new Algorithms  

As stated above, it should be relatively easy to add new algorithms. Usually, all algorithms inherit from `ClassicalAlgorithmsBASE` or `GeneralAlgorithmsBASE` and are defined in either `pulsedjax.core.base_classic_algorithms.py` or `pulsedjax.core.base_general_optimization.py`.  
All algorithm classes need to provide an `algorithm.step(self, descent_state, measurement_info, descent_info)` method as well as an `algorithm.initialize_run(self, population)` method: `step()` performs one iteration of the algorithm, and `initialize_run()` prepares all provided data for the retrieval and returns the initial `descent_state` together with a `lax.scan`-compatible step function.  
Below is a bare-bones implementation of an `AlgorithmBASE` class.

```python
from jax.tree_util import Partial
from pulsedjax.utilities import scan_helper


# Classic_or_GeneralAlgorithmsBASE stands for either ClassicalAlgorithmsBASE or GeneralAlgorithmsBASE
class MyAlgorithmBASE(Classic_or_GeneralAlgorithmsBASE):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._name = "MyAlgorithm"



    def step(self, descent_state, measurement_info, descent_info):
        """
        Performs one iteration of MyAlgorithm.
        
        Args:
            descent_state (Pytree): the changing variables, e.g. the population and prng keys
            measurement_info (Pytree): the measured data and fixed measurement parameters
            descent_info (Pytree): the static algorithm settings

        Returns:
            tuple[Pytree, jnp.array], the updated descent state and the current trace errors of the population.
        """
        
        ...  # placeholder: update descent_state and compute trace_error for every individual

        return descent_state, trace_error.reshape(-1,1)
    


    def initialize_run(self, population):
        """
        Prepares all provided data and parameters for the reconstruction. 
        Here the final shape/structure of descent_state, measurement_info and descent_info are determined. 

        Args:
            population (Pytree): the initial guess as created by self.create_initial_population()
        
        Returns:
            tuple[Pytree, Callable], the initial descent state and the step-function of the algorithm.

        """

        # usually the final structure of measurement_info is defined through RetrievePulsesMETHOD
        measurement_info = self.measurement_info

        # add the setting attributes into descent_info
        self.descent_info = self.descent_info.expand(...)  # placeholder: settings as keyword arguments
        descent_info = self.descent_info

        # add the initial guess and other changing variables like prng keys into descent_state
        self.descent_state = self.descent_state.expand(population=population)  # plus e.g. prng keys
        descent_state = self.descent_state

        # bind the static inputs and convert step() into a lax.scan-compatible form
        do_step = Partial(self.step, measurement_info=measurement_info, descent_info=descent_info)
        do_step = Partial(scan_helper, actual_function=do_step, number_of_args=1, number_of_xs=0)
        return descent_state, do_step
```
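The `step`/`initialize_run` contract can be exercised without `pulsedjax` on a toy problem. The sketch below minimizes ||A x - b||² for every individual of a population with plain gradient descent; `DescentState`, `DescentInfo`, `ToyGradientDescent` and `scan_adapter` are hypothetical. In particular, `scan_adapter` only illustrates what a wrapper like `scan_helper` has to achieve, namely turning `state -> (state, error)` into the `(carry, x) -> (carry, y)` signature expected by `jax.lax.scan`; it is not the `pulsedjax` implementation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.tree_util import Partial


class DescentState(NamedTuple):
    population: jax.Array  # shape (n_individuals, n_params); changes every iteration


class DescentInfo(NamedTuple):
    learning_rate: float   # static setting


def scan_adapter(carry, xs, actual_function):
    """Adapt state -> (state, error) to the (carry, x) -> (carry, y) form of jax.lax.scan."""
    del xs  # the step needs no per-iteration input
    return actual_function(carry)


class ToyGradientDescent:
    """Toy algorithm: minimizes ||A x - b||^2 for every individual x of the population."""

    def __init__(self, A, b, learning_rate=0.1):
        self.measurement_info = {"A": A, "b": b}
        self.learning_rate = learning_rate

    def step(self, descent_state, measurement_info, descent_info):
        def loss(x):
            residual = measurement_info["A"] @ x - measurement_info["b"]
            return jnp.sum(residual**2)

        errors, grads = jax.vmap(jax.value_and_grad(loss))(descent_state.population)
        population = descent_state.population - descent_info.learning_rate * grads
        return descent_state._replace(population=population), errors.reshape(-1, 1)

    def initialize_run(self, population):
        descent_state = DescentState(population=population)
        descent_info = DescentInfo(learning_rate=self.learning_rate)
        do_step = Partial(self.step, measurement_info=self.measurement_info, descent_info=descent_info)
        do_step = Partial(scan_adapter, actual_function=do_step)
        return descent_state, do_step


A = jnp.array([[2.0, 0.0], [0.0, 1.0]])
b = jnp.array([2.0, -1.0])
algo = ToyGradientDescent(A, b, learning_rate=0.1)
descent_state, step = algo.initialize_run(jnp.zeros((3, 2)))
descent_state, errors = jax.lax.scan(step, descent_state, length=200)
# errors has shape (200, 3, 1): one column of errors per iteration
```

Binding `measurement_info` and `descent_info` with `Partial` leaves only the changing `descent_state` as the scan carry, the same split the bare-bones class above uses.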