# Fourier construction of single dipole sources

This notebook demonstrates that the dipole field of a single source can be constructed with good accuracy from the Fourier model. It does not attempt to learn any dependence on the source, only the dependence on spatial position.

In [None]:
import jax.random as jr
import jax
import jax.numpy as jnp
import optax
from hypermagnetics import plots
from hypermagnetics.sources import configure
from hypermagnetics.models.hyper_fourier import FourierModel
from hypermagnetics.runner import fit
import matplotlib.pyplot as plt
import scienceplots  # noqa

plt.style.use(["science", "ieee"])

## I. Target definition and data generation

We use the standard form of the field from a single dipole source. The dipole field at $\mathbf{r}$ from a source with moment $\mathbf{m}_0$ at position $\mathbf{r}_0$ is computed via the scalar potential as

$$\mu_0\mathbf{H}_{\odot}(\mathbf{r}) = -\nabla \underbrace{\overbrace{\frac{1}{2\pi |\mathbf{r}-\mathbf{r}_0|}}^{\text{Surface of 2D ball}}\overbrace{\frac{\mathbf{m}_0\cdot(\mathbf{r}-\mathbf{r}_0)}{|\mathbf{r}-\mathbf{r}_0|}}^{\text{dipole term}}}_{\text{scalar potential }\psi_0}.$$
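To make the formula concrete, here is a minimal JAX sketch that evaluates the field by automatic differentiation of the 2D scalar potential, $\psi = \mathbf{m}\cdot(\mathbf{r}-\mathbf{r}_0)/(2\pi|\mathbf{r}-\mathbf{r}_0|^2)$, which is the product of the two bracketed factors above. The function names `scalar_potential`, `field` and `field_batch` are illustrative and not part of `hypermagnetics`; the output is $\mu_0\mathbf{H}$, matching the left-hand side, and it is singular at $\mathbf{r} = \mathbf{r}_0$.

```python
import jax
import jax.numpy as jnp


def scalar_potential(r, r0, m0):
    """2D dipole potential: m0 . (r - r0) / (2 pi |r - r0|^2)."""
    d = r - r0
    return jnp.dot(m0, d) / (2 * jnp.pi * jnp.dot(d, d))


def field(r, r0, m0):
    """mu_0 H at r: minus the gradient of the potential with respect to r."""
    return -jax.grad(scalar_potential)(r, r0, m0)


# Vectorise over a batch of evaluation points; the source stays fixed.
field_batch = jax.vmap(field, in_axes=(0, None, None))

r0 = jnp.array([0.0, 0.0])
m0 = jnp.array([1.0, 0.0])
points = jnp.array([[1.0, 0.0], [0.0, 1.0]])
# On the dipole axis the field is +m0 / (2 pi); perpendicular to it, -m0 / (2 pi).
print(field_batch(points, r0, m0))
```

Differentiating the potential rather than coding $\mathbf{H}$ by hand keeps the sketch tied to the equation; the closed form $\left(2(\mathbf{m}\cdot\hat{\mathbf{d}})\hat{\mathbf{d}} - \mathbf{m}\right)/(2\pi|\mathbf{d}|^2)$, with $\mathbf{d} = \mathbf{r}-\mathbf{r}_0$, gives the same result.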

We'll use a single-source example sampled on a grid for training and generalise to off-grid locations.
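The distinction between grid and off-grid points can be sketched as follows. This assumes that `lim` in the config cell below is the half-width of a square domain $[-\mathrm{lim}, \mathrm{lim}]^2$ and `res` is the number of grid points per axis; that reading is an assumption, not confirmed by `configure`. The helpers `make_grid` and `make_offgrid` are hypothetical.

```python
import jax.numpy as jnp
import jax.random as jr


def make_grid(lim: float, res: int) -> jnp.ndarray:
    """Regular res x res grid on [-lim, lim]^2, flattened to shape (res * res, 2)."""
    axis = jnp.linspace(-lim, lim, res)
    xx, yy = jnp.meshgrid(axis, axis, indexing="ij")
    return jnp.stack([xx.ravel(), yy.ravel()], axis=-1)


def make_offgrid(lim: float, n: int, key) -> jnp.ndarray:
    """n uniformly random points in the same square (almost surely off the grid)."""
    return jr.uniform(key, (n, 2), minval=-lim, maxval=lim)


train_points = make_grid(lim=2, res=50)
test_points = make_offgrid(lim=2, n=1000, key=jr.PRNGKey(0))
print(train_points.shape, test_points.shape)  # (2500, 2) (1000, 2)
```

Uniform random sampling is only one way to probe generalisation; a shifted or finer grid would work equally well.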

In [None]:
source_config = {
    "n_samples": 1,
    "lim": 2,
    "res": 50,
}
source = configure(**source_config, n_sources=1, key=jr.PRNGKey(40))

In [None]:
class FourierDecomposition(FourierModel):
    hypermodel: jax.Array
    lfmin: jax.Array
    lfmax: jax.Array
    bias: jax.Array
    order: int

    def __init__(self, order):
        self.order = order
        self.lfmin = jnp.ones(1) * -order / 2
        self.lfmax = jnp.ones(1)
        self.bias = jnp.ones(1)
        self.hypermodel = jnp.ones((4 * order**2))  # No hypermodel, just a set of weights

    def prepare_weights(self, sources):
        # Weights don't depend on source information; just fit them.
        return self.hypermodel, self.bias.squeeze()
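The internals of `FourierModel` are not shown here. Purely as an illustration, one plausible way a flat vector of `4 * order**2` weights could parameterise a 2D function is a separable sine/cosine basis with `order` frequencies per axis, giving four sin/cos combinations per frequency pair. Because such a model is linear in its weights, the weights (plus a bias) can be fit to a dipole potential by least squares instead of gradient descent. The basis, the frequency spacing $k\pi/\mathrm{lim}$, and the names `fourier_features`, `dipole_potential` and `fit_weights` are assumptions, not the library's implementation.

```python
import jax.numpy as jnp


def fourier_features(points, order, lim):
    """Separable 2D sin/cos features with 4 * order**2 columns per point.

    Hypothetical frequencies k * pi / lim for k = 1..order.
    """
    k = jnp.arange(1, order + 1) * jnp.pi / lim   # (order,)
    x = points[:, 0:1] * k                         # (n, order)
    y = points[:, 1:2] * k                         # (n, order)
    cx, sx, cy, sy = jnp.cos(x), jnp.sin(x), jnp.cos(y), jnp.sin(y)
    # Outer products over the two frequency axes: four (n, order, order) blocks.
    blocks = [a[:, :, None] * b[:, None, :] for a in (cx, sx) for b in (cy, sy)]
    return jnp.concatenate([blk.reshape(points.shape[0], -1) for blk in blocks], axis=1)


def dipole_potential(points, r0, m0):
    """2D dipole potential m0 . (r - r0) / (2 pi |r - r0|^2) at each point."""
    d = points - r0
    return (d @ m0) / (2 * jnp.pi * jnp.sum(d**2, axis=1))


def fit_weights(points, target, order, lim):
    """Least-squares weights w and bias b such that target ~ features @ w + b."""
    feats = fourier_features(points, order, lim)
    design = jnp.concatenate([feats, jnp.ones((feats.shape[0], 1))], axis=1)
    coeffs, *_ = jnp.linalg.lstsq(design, target)
    return coeffs[:-1], coeffs[-1]


# Usage: fit a dipole placed off the 50 x 50 grid so the potential stays finite.
axis = jnp.linspace(-2.0, 2.0, 50)
xx, yy = jnp.meshgrid(axis, axis, indexing="ij")
grid = jnp.stack([xx.ravel(), yy.ravel()], axis=-1)
psi = dipole_potential(grid, r0=jnp.array([0.3, -0.4]), m0=jnp.array([1.0, 0.0]))
w, b = fit_weights(grid, psi, order=8, lim=2.0)
pred = fourier_features(grid, 8, 2.0) @ w + b
print(w.shape, float(jnp.linalg.norm(pred - psi) / jnp.linalg.norm(psi)))  # w.shape == (256,)
```

This shortcut only works with fixed frequencies. The notebook's model also carries `lfmin` and `lfmax`, which a fixed basis like this one does not learn, so it is fit with the optimiser instead.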

In [None]:
model = FourierDecomposition(order=25)
plots(source, model)

In [None]:
schedule = [
    {"log_learning_rate": -2.0, "epochs": 5000},
    {"log_learning_rate": -3.0, "epochs": 5000},
    {"log_learning_rate": -4.0, "epochs": 5000}
]

for trainer_config in schedule:
    learning_rate = 10 ** trainer_config["log_learning_rate"]
    optim = optax.adamw(learning_rate, b1=0.95)
    model = fit(trainer_config, optim, model, source, source, every=100)

In [None]:
print(model.lfmin, model.lfmax)
plots(source, model=model)