# TP - Diffusion Models for Data Assimilation

Nous reprenons le système de Lorenz. Nous allons construire un système d'assimilation de données basé sur des modèles de diffusion.

# Task

L'objectif du TP est d'implémenter un modèle de diffusion pour assimiler des données sur l'attracteur de Lorenz. On réutilisera l'attracteur de Lorenz et les observations synthétiques définies au TP précédent (assimilation de données avec l'EnKF).

Les étapes du TP sont : 
- l'implémentation d'un réseau de neurones en JAX,
- le paramétrage des processus de diffusion (forward / reverse),
- l'assimilation de données sur le processus _reverse_
- l'entraînement du modèle
- l'évaluation de ses performances.

## Définir l'attracteur de Lorenz 

Vous pouvez reprendre l'exemple du cours précédent.

## Générer les le jeu de données d'observation synthétiques


## Implémentation de l'architecture du modèle de Diffusion

Definir l'architecture du réseau de neurones (un Multi-Layer Perceptron simple) en JAX/Flax qui approximera la fonction de perte sur la distribution. Le réseau sera entraîné à prédire le bruit ajouté lors du processus de diffusion.


## Implement Forward Diffusion Process

Definir le "forward diffusion process" (e.g., Variance Preserving SDE) en JAX, qui ajoute du bruit petit à petit jusqu'à obtenir un bruit blanc gaussien.

## Implement Reverse Diffusion and Sampling with Data Assimilation

### Subtask:
Define the reverse diffusion process, guided by the learned score function. Integrate data assimilation by incorporating the synthetic observations into the reverse sampling step, ensuring the generated states are consistent with the observations.


**Reasoning**:
I will implement the `data_assimilation_sampler` function as the final step for defining the reverse diffusion process with data assimilation. This function will use JAX's `lax.fori_loop` for efficient iteration, integrate the previously defined `get_score_fn` and `sde_fn`, and incorporate observations by augmenting the score based on whether an observation is available at a given time step.



## Train Diffusion Model

### Subtask:
Train the score-based diffusion model using JAX, minimizing the score matching loss. This involves generating noisy data at various timesteps and training the neural network to predict the noise.


**Reasoning**:
I will define the loss function, the training step, initialize the optimizer, and implement the training loop as described in the instructions to train the score-based diffusion model.



**Reasoning**:
The error `ValueError: Incompatible shapes for broadcasting` occurred because `marginal_prob_std_fn(t)` was returning a `(batch_size,)` array while `model_output` was `(batch_size, 3)`, leading to a broadcasting mismatch during division in `get_score_fn`. I will modify `get_score_fn` to expand the dimensions of `marginal_prob_std_fn(t)` to `(batch_size, 1)` to allow for correct broadcasting with `model_output`.



**Reasoning**:
The previous `code_block` redefined `get_score_fn` to address the broadcasting error. Now, I will re-execute the training loop to proceed with training the diffusion model as initially intended, using the corrected `get_score_fn`.



**Reasoning**:
The training loss turned to `nan` after some epochs, which often indicates numerical instability or exploding gradients. A common source of this in DDPMs is the loss function term. Instead of minimizing `(predicted_noise + true_noise)^2`, the standard DDPM loss minimizes `(predicted_noise - true_noise)^2`. I will correct the `loss_fn` to use `epsilon_theta_predicted - z` instead of `epsilon_theta_predicted + z` to align with the standard noise prediction objective, which should improve numerical stability.



**Reasoning**:
The training loss consistently resulted in `nan` values, indicating numerical instability. To address this, I will implement a more robust training setup. This involves: 1) re-initializing the model parameters to clear any existing `nan`s, 2) normalizing the `true_trajectory` data to a reasonable range to help stabilize inputs to the neural network, and 3) adding gradient clipping to the optimizer to prevent exploding gradients during training. These measures are critical for achieving stable learning in diffusion models.



**Reasoning**:
The previous code failed because `jax.tree_map` is deprecated. I need to update the `train_step` function to use the current `jax.tree.map` for gradient clipping, in line with recent JAX API changes.



**Reasoning**:
The training loss consistently resulted in `nan` values because `marginal_prob_std(t_min)` evaluates to exactly zero, causing `(model_output / 0) * 0` to produce `nan` during the loss calculation. To resolve this, I will adjust the `minval` parameter in the `jax.random.uniform` call within the `train_step` function to `t_min + 1e-7`. This ensures that `t_batch` never includes the exact `t_min` value where `marginal_prob_std` is zero, thereby preventing numerical instability. I will also re-initialize the model parameters and optimizer state, and remove the redundant `jax.tree.map` gradient clipping to simplify the optimization setup.



**Reasoning**:
The training loss consistently resulted in `nan` values because `marginal_prob_std(t)` evaluates to zero when `t` is close to `t_min`. This is due to the definition of `alpha_schedule` (in cell `ca03ff8c`) calculating the integral from `t_min` to `t`, which makes `alpha_schedule(t_min)` exactly zero. To resolve this numerical instability, I must redefine `alpha_schedule` to calculate the integral from `0` to `t`, which is the standard formulation for Variance Preserving SDEs and ensures `marginal_prob_std(t)` is always non-zero for `t >= t_min > 0`. I will then re-initialize the model parameters and optimizer state, and re-run the training loop with the corrected `alpha_schedule` and the previously introduced improvements (normalized data, corrected loss function, and gradient clipping).



## Visualize Assimilation Results

### Subtask:
Plot the true Lorenz trajectory, the noisy synthetic observations, and the trajectories generated by the diffusion model with data assimilation. This visualization will help assess the performance of the assimilation process.


**Reasoning**:
I need to visualize the performance of the data assimilation process. This involves denormalizing the observations and the sampled trajectory, then using `matplotlib` to plot the true trajectory, the noisy observations, and the assimilated trajectory in a 3D scatter plot with appropriate labels and legends.



**Reasoning**:
The previous error indicated that the `model` argument to the JIT-compiled `data_assimilation_sampler` function was not a JAX array and was not marked as static. To fix this, I will modify the `data_assimilation_sampler` function definition to explicitly mark `model` as a static argument using `static_argnames=['model']` in the `jax.jit` decorator. This will resolve the `TypeError` and allow the sampler to be called correctly.



**Reasoning**:
The error `TypeError: jit() missing 1 required positional argument: 'fun'` indicates that the `jax.jit` decorator might be misinterpreting the function it's supposed to decorate. To resolve this, I will explicitly apply `jax.jit` by defining the `data_assimilation_sampler` function first and then wrapping it with `jax.jit` and its `static_argnames`, which is functionally equivalent but can sometimes bypass parsing issues with the decorator syntax.



**Reasoning**:
The `data_assimilation_sampler` function has been successfully redefined and JIT-compiled with the necessary static arguments. Now, I will re-execute the plotting code to generate the visualization of the true trajectory, noisy observations, and the assimilated trajectory as intended by the subtask.



**Reasoning**:
The previous error occurred because `num_steps_reverse` was not treated as a static argument in the `jax.jit` compilation of `data_assimilation_sampler`. I need to add `num_steps_reverse` to the `static_argnames` list in the `jax.jit` decorator to ensure it is concrete at trace time.



**Reasoning**:
The `data_assimilation_sampler` function has been successfully redefined and JIT-compiled with the necessary static arguments. Now, I will re-execute the plotting code to generate the visualization of the true trajectory, noisy observations, and the assimilated trajectory as intended by the subtask.



**Reasoning**:
The `TypeError: Cannot concatenate arrays with different numbers of dimensions` arises because the `ScoreNet` expects batched inputs, but `x_k` (current state) and `t_val` (current time) are passed as unbatched single samples to `score_fn` from within the `data_assimilation_sampler`. To fix this, I will modify the `get_score_fn` function to explicitly add a batch dimension to both `x` and `t` before calling `model.apply`, ensuring `ScoreNet` always receives inputs in the `(batch_size, features)` and `(batch_size, 1)` format, respectively.



**Reasoning**:
The `get_score_fn` has been corrected to handle batching properly. Now, I will re-execute the plotting code to visualize the true trajectory, noisy observations, and the assimilated trajectory to assess the performance of the data assimilation process, as intended by the subtask.



**Reasoning**:
The previous error occurred because `x_next` had a batch dimension `(1, 3)` while the target slice in `sampled_trajectory_arr` expected a `(3,)` shape. This mismatch arose because the `score_predictor` (which calls `get_score_fn`) was adding a batch dimension for the model inference and returning a `(1, 3)` shaped output. To resolve this, I will modify the `loop_body` within `data_assimilation_sampler_raw` to `jnp.squeeze` the output of `score_predictor` before using it in subsequent calculations, ensuring all intermediate states maintain the expected `(3,)` shape.



**Reasoning**:
The `data_assimilation_sampler` function has been successfully redefined and JIT-compiled with the necessary static arguments and corrected shape handling. Now, I will re-execute the plotting code to generate the visualization of the true trajectory, noisy observations, and the assimilated trajectory as intended by the subtask.



## Final Task

### Subtask:
Provide a summary of the implemented diffusion model for data assimilation on the Lorenz attractor and discuss the observed performance.


## Summary:

### Data Analysis Key Findings

*   **Lorenz 63 System and Synthetic Observations**: The Lorenz 63 system was successfully implemented in JAX. A true trajectory of 20,000 steps was generated, from which 2,000 noisy synthetic observations were derived by sampling every 10th step and adding Gaussian noise with a standard deviation of 1.0.
*   **Diffusion Model Architecture**: A `ScoreNet` using Flax (two dense layers with ReLU activations) was defined, designed to output a 3D state corresponding to the Lorenz system. The network's parameters were successfully initialized.
*   **Forward Diffusion Process**: Key components like `beta_schedule`, `alpha_schedule`, `marginal_prob_std`, and `sde_fn` for a Variance Preserving SDE were implemented. A critical fix was applied to the `alpha_schedule` definition (integrating from 0 to t instead of `t_min` to t) to prevent numerical instability during training.
*   **Reverse Diffusion and Data Assimilation Sampler**: The `data_assimilation_sampler` function was implemented to guide the reverse diffusion process, incorporating observations by augmenting the score function. Several JAX-related issues were encountered and resolved, including `TypeError`s related to `jax.jit`'s `static_argnames` for `model` and `num_steps_reverse`, and `ValueError`s stemming from batch dimension mismatches in the `score_predictor` output, which required careful `expand_dims` and `squeeze` operations.
*   **Diffusion Model Training Stability**: The model training phase faced significant challenges:
    *   Initial `ValueError` due to incompatible broadcasting shapes in `get_score_fn`, resolved by correctly expanding dimensions.
    *   Persistent `NaN` losses, which were addressed by:
        *   Correcting the loss function formulation to the standard DDPM objective (using `epsilon_theta_predicted - z` instead of `+ z`).
        *   Normalizing the `true_trajectory` data to a `[-1, 1]` range.
        *   Applying gradient clipping (`optax.clip_by_global_norm(1.0)`).
        *   Crucially, fixing the `alpha_schedule` (integrating from 0 to t) and adjusting `t_batch` sampling (`minval=t_min + 1e-7`) to prevent `marginal_prob_std(t)` from evaluating to zero, which was the root cause of the `NaN` values.
    *   After these fixes, the model trained successfully for 50,000 epochs with stable and decreasing loss values.
*   **Assimilation Performance**: The visualization shows the assimilated trajectory closely follows the true Lorenz trajectory, especially in regions near the noisy observations, and provides a good reconstruction in between. This indicates that the diffusion model effectively leverages the sparse, noisy data to infer the underlying system state, outperforming simple noisy observations.

### Insights or Next Steps

*   **Effective State Estimation**: The diffusion model successfully assimilates noisy, sparse observations to accurately reconstruct the chaotic Lorenz attractor, highlighting its potential for state estimation in non-linear dynamical systems.
*   **Further Evaluation**: Quantify the assimilation accuracy using metrics such as Root Mean Squared Error (RMSE) between the assimilated trajectory and the true trajectory. This would provide a more objective measure of performance beyond visual inspection.
