# Example using the NN Ensemble class

This is a simple/silly example that uses an existing NN surrogate of the LCLS injector to show how to set up an NN ensemble in lume-model.

The surrogate can be installed from [https://github.com/slaclab/lcls_cu_injector_ml_model](https://github.com/slaclab/lcls_cu_injector_ml_model). **You will need to have the model files installed locally in order to run this notebook**.



In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from lume_torch.utils import variables_from_yaml
from lume_torch.models import TorchModel
from lume_torch.models.ensemble import NNEnsemble
from lume_torch.variables import DistributionVariable

  from .autonotebook import tqdm as notebook_tqdm


## Create model

In [6]:
# Relative path to the resources directory of the locally installed LCLS
# injector surrogate; it is used as a string prefix, so keep the trailing slash.
path_to_lcls_model = "../../../lcls_cu_injector_ml_model/resources/"

In [8]:
# load transformers
# NOTE: weights_only=False makes torch.load fall back to full unpickling,
# which can execute arbitrary code -- only load files from a trusted source.
input_transformer_file = path_to_lcls_model + "input_sim_to_nn.pt"
output_transformer_file = path_to_lcls_model + "output_sim_to_nn.pt"
input_sim_to_nn = torch.load(input_transformer_file, weights_only=False)
output_sim_to_nn = torch.load(output_transformer_file, weights_only=False)

In [9]:
# load in- and output variable specification
import os

variables_file = path_to_lcls_model + "model/sim_variables.yml"
# Fail early with a clear message when the specification file is missing.
# NOTE(review): the recorded run of this cell ended in an opaque
# "AttributeError: 'str' object has no attribute 'items'"; presumably the
# path did not exist and was not read as a file -- confirm against
# variables_from_yaml.
if not os.path.isfile(variables_file):
    raise FileNotFoundError(
        f"Variable specification file not found: {variables_file!r}. "
        "Install the LCLS injector model files locally and adjust "
        "path_to_lcls_model."
    )
input_variables, output_variables = variables_from_yaml(variables_file)

AttributeError: 'str' object has no attribute 'items'

In [None]:
# Get example inputs ready for test
# NOTE: weights_only=False means full unpickling -- trusted files only.
inputs_small = torch.load(
    path_to_lcls_model + "info/inputs_small.pt", weights_only=False
)
outputs_small = torch.load(
    path_to_lcls_model + "info/outputs_small.pt", weights_only=False
)

# Map each input variable name to its column of the example input data
# (column order follows the order of input_variables).
input_dict = {
    variable.name: inputs_small[:, column]
    for column, variable in enumerate(input_variables)
}

## Create a wrapper around the model to add some random noise to the outputs

Note that without this wrapper the variance would be zero, because every member of our ensemble is the same NN, and instantiating the output distribution would throw an error.

In [None]:
class NoisyLCLSSurrogate(TorchModel):
    """LCLS surrogate wrapper whose outputs are randomly perturbed.

    Every call to ``_evaluate`` draws one random factor from a normal
    distribution with zero mean and standard deviation ``noise_level`` and
    adds ``output * factor`` to each output of the wrapped model, so that
    repeated evaluations of the same input differ.
    """

    # standard deviation of the relative noise applied to the outputs
    noise_level: float = 0.01

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _evaluate(self, input_dict):
        """Evaluate the parent model and return its outputs with noise added."""
        clean_outputs = super()._evaluate(input_dict)
        # a single scalar draw per call, shared by all outputs
        noise_factor = np.random.normal(0, self.noise_level)
        noisy_outputs = {}
        for name, value in clean_outputs.items():
            noisy_outputs[name] = value + value * noise_factor
        return noisy_outputs

In [None]:
# Build the noisy surrogate from the stored model file, reusing the variable
# specification and the transformers loaded above.
model_file = path_to_lcls_model + "model/model.pt"
noisy_model = NoisyLCLSSurrogate(
    model=model_file,
    input_variables=input_variables,
    output_variables=output_variables,
    input_transformers=[input_sim_to_nn],
    output_transformers=[output_sim_to_nn],
)

## Create ensemble

In [None]:
# The ensemble holds the same noisy model instance several times; the members
# only differ through the random noise drawn on each evaluation.
n_members = 10
models_list = [noisy_model for _ in range(n_members)]

# One distribution-valued output variable per predicted quantity.
ensemble_output_names = (
    "sigma_x",
    "sigma_y",
    "sigma_z",
    "norm_emit_x",
    "norm_emit_y",
)
ensemble_output_variables = [
    DistributionVariable(name=output_name) for output_name in ensemble_output_names
]

nn_ensemble = NNEnsemble(
    models=models_list,
    input_variables=input_variables,
    output_variables=ensemble_output_variables,
)

## Test on example data

In [None]:
# Evaluate the ensemble on the example inputs; the result is iterated with
# .items() further down, and each value exposes .mean and .variance.
ensemble_out = nn_ensemble.evaluate(input_dict)

In [None]:
# show the raw ensemble output
ensemble_out

In [None]:
# example mean/var for first sample of all outputs
for name, distribution in ensemble_out.items():
    first_mean = distribution.mean[0]
    first_variance = distribution.variance[0]
    print(name, first_mean, first_variance)

In [None]:
# plot example data and predictions
# One panel per ensemble output: the reference outputs sorted by value, with a
# band of mean +/- one standard deviation of the ensemble prediction.
nrows, ncols = 3, 2
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 15))
for i, output_name in enumerate(nn_ensemble.output_names):
    row, col = divmod(i, ncols)
    panel = ax[row, col]
    if i >= outputs_small.shape[1]:
        # no reference data column for this output: leave the panel empty
        continue
    reference = outputs_small[:, i]
    sort_idx = torch.argsort(reference)
    x_axis = torch.arange(outputs_small.shape[0])
    prediction = ensemble_out[output_name]
    mean = prediction.mean[sort_idx]
    std = torch.sqrt(prediction.variance[sort_idx])
    lower = mean - std
    upper = mean + std
    panel.fill_between(
        x=x_axis, y1=lower, y2=upper, color="C1", label="ensemble predictions"
    )
    panel.plot(x_axis, outputs_small[sort_idx, i], "C0x", label="outputs", alpha=0.5)
    panel.legend()
    panel.set_title(output_name)
# the grid has one more panel than is used here; hide the last one
ax[-1, -1].axis("off")
fig.tight_layout()