In [1]:
from typing import Tuple

import torch

This notebook summarizes the proposed approach to uncertainty quantification \
and propagation for atomistic modelling. \
A common way to quantify the uncertainty of neural network predictions is \
to train the network to directly estimate both the target mean $\mu$ and the variance $\sigma^2$ ([Nix and Weigend 1994](https://doi.org/10.1109/ICNN.1994.374138)).\
\
Machine learning models (such as neural networks) of atomistic properties \
usually aim to predict a structure-wise property, e.g. the energy of a molecule. \
The global prediction is typically obtained by summing atom-wise local contributions. \
This choice is made so that the model can be applied to systems of different sizes \
and so that the computational cost of evaluating the model scales linearly with system size. \

In [None]:
class MVE(torch.nn.Module):
    """Mean variance estimator (MVE).

    A small fully connected network (one Tanh hidden layer) that predicts both
    a mean and a variance for each input, in the spirit of Nix and Weigend (1994).

    Args:
        n_out: number of output neurons of the final linear layer. ``forward``
            reads output 0 as the mean and output 1 as the (pre-softplus)
            variance, so ``MVE.forward`` needs ``n_out >= 2``; extra outputs
            are ignored by it. Subclasses may reinterpret the outputs.
        n_in: number of input features (default 1).
        n_hidden: width of the hidden layer (default 64).
    """

    def __init__(self, n_out: int = 2, n_in: int = 1, n_hidden: int = 64) -> None:
        super().__init__()
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(n_in, n_hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(n_hidden, n_out),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predict mean and variance for ``x``.

        Args:
            x: input tensor whose last dimension has size ``n_in``; leading
                (batch) dimensions are preserved.

        Returns:
            ``(mean, var)``: each has the shape of ``x`` without its last
            dimension. ``var`` is positive (up to floating-point underflow
            of the softplus).
        """
        out = self.nn(x)
        # Index the last (output) axis so any leading batch shape works;
        # for 2-D input this is identical to out[:, 0] / out[:, 1].
        mean = out[..., 0]
        raw_var = out[..., 1]

        # Softplus maps the unconstrained network output to a positive variance.
        # Alternatively, the log-variance could be predicted directly.
        var = torch.nn.functional.softplus(raw_var)

        return mean, var


class ShallowEnsemble(MVE):
    """Shallow ensemble model.

    Reuses the ``MVE`` network body but with ``n_members`` output neurons.
    Each output neuron is treated as one ensemble ("committee") member. The
    predicted mean and variance are the mean and variance across members.

    Args:
        n_members: number of ensemble members, i.e. output neurons
            (default 64, the value previously hard-coded).
    """

    def __init__(self, n_members: int = 64) -> None:
        super().__init__(n_out=n_members)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ensemble mean and variance for ``x``.

        Args:
            x: input tensor passed to the underlying network.

        Returns:
            ``(mean, var)`` computed over the last (member) axis of the
            committee predictions.
        """
        committee = self.return_committee(x)
        mean = torch.mean(committee, dim=-1)
        # torch.var defaults to the unbiased (Bessel-corrected) estimator,
        # matching the original code; it yields NaN for a single member.
        var = torch.var(committee, dim=-1)

        return mean, var

    def return_committee(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw per-member predictions (last axis = members)."""
        return self.nn(x)



In [None]:
def z(x):
    """Return ``x`` squared.

    Equivalent to ``x**2``; works for Python numbers as well as arrays/tensors
    that implement ``__pow__`` (applied element-wise there).
    """
    squared = pow(x, 2)
    return squared