# Hypernetwork modelling of 3D demagnetising field

In [1]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import scienceplots  # noqa
from matplotlib import patches

from hypermagnetics import plots
from hypermagnetics.measures import loss, accuracy
from hypermagnetics.models.hyper_mlp import HyperLayer
from hypermagnetics.runner import fit
from hypermagnetics.sources import configure

# Select the "science" and "ieee" Matplotlib styles for every figure drawn below.
# NOTE(review): these style names presumably come from the `scienceplots`
# import above (kept only for its side effect, hence the `noqa`) — confirm.
plt.style.use(["science", "ieee"])

## I. Target definition and data generation

Here we are modelling the field around three-dimensional prism sources that follow the potential/field derived in the `sources.ipynb` notebook, i.e.

$$
\phi = \mathbf{m}\cdot\left.\left(\mathbb{F}\cdot\mathbf{r}\right)\right|^{\{+a,+b,+c\}}_{\{-a,-b,-c\}},
$$

where the evaluation of the shape tensor $\mathbb{F}|_{abc}$ is implemented via `configure` in `src/sources.py`.

In [2]:
# Sampling axis shared by the plots: `res` evenly spaced points on [-lim, lim].
lim = 4
res = 32
row = jnp.linspace(-lim, lim, res)

# Keyword arguments handed to `configure` for the training set.
train_config = dict(
    shape="prism",
    n_samples=1000,
    n_sources=1,
    seed=101,
    dim=3,
    lim=lim,
    res=res,
)

# The test set reuses every setting except the random seed.
test_config = {**train_config, "seed": 123}

prism_data = configure(**train_config)
test_data = configure(**test_config)

# Pull out the training entries that the later cells work with.
keys = ["sources", "r", "potential_grid", "field_grid"]
sources, r, potential, field = map(prism_data.get, keys)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!


Metal device set to: Apple M2 Ultra

systemMemory: 192.00 GB
maxCacheSize: 72.00 GB





In [3]:
# Instantiate the model and echo its hyperparameters.
model = HyperLayer(in_size=3, width=64, depth=3, hwidth=1, hdepth=3, seed=30)
print(model.hparams)

# Baseline metrics on the training data, evaluated before any fitting.
initial_loss = loss(model, prism_data)
initial_accuracy = accuracy(model, prism_data)
print(initial_loss, initial_accuracy)

{'in_size': 3, 'width': 64, 'depth': 3, 'hwidth': 1, 'hdepth': 3, 'seed': 30}
0.20552762 100.215004


In [4]:
# Training settings; the "params" entry is unpacked as keyword arguments
# for the Adam optimiser.
trainer_config = {
    "epochs": 10000,
    "params": {"learning_rate": 1e-3},
}
adam_kwargs = trainer_config["params"]
optim = optax.adam(**adam_kwargs)

# Fit on the training data, passing the held-out set alongside it. The
# recorded output shows one log line per 100 epochs, matching `every=100`.
model = fit(
    trainer_config,
    optim,
    model,
    prism_data,
    test_data,
    every=100,
)



{'epoch': 0, 'train_loss': 0.20552761852741241, 'train_err': 99.88005828857422, 'test_err': 99.91668701171875}
{'epoch': 100, 'train_loss': 0.11976288259029388, 'train_err': 78.1854476928711, 'test_err': 79.23165893554688}
{'epoch': 200, 'train_loss': 0.07344114780426025, 'train_err': 59.531272888183594, 'test_err': 61.73383712768555}
{'epoch': 300, 'train_loss': 0.058430109173059464, 'train_err': 52.03592300415039, 'test_err': 54.97881317138672}
{'epoch': 400, 'train_loss': 0.048837900161743164, 'train_err': 46.85567855834961, 'test_err': 50.28560256958008}
{'epoch': 500, 'train_loss': 0.04481825232505798, 'train_err': 44.565677642822266, 'test_err': 48.19490051269531}
{'epoch': 600, 'train_loss': 0.04229156672954559, 'train_err': 43.0803337097168, 'test_err': 46.89666748046875}
{'epoch': 700, 'train_loss': 0.04053880646824837, 'train_err': 42.059181213378906, 'test_err': 46.01293182373047}
{'epoch': 800, 'train_loss': 0.039299074560403824, 'train_err': 41.324710845947266, 'test_err':

In [None]:
# Visualise one training sample: a 2D slice of the scalar potential (left)
# and streamlines of the in-plane field components (right).
idx = 20

# The sample's source row is split into two halves; the first half is drawn
# as the arrow `m`, the second is used as the centre `r0` of the outline.
m, r0 = jnp.split(sources[idx][0], 2, axis=-1)
x, y, z = jnp.meshgrid(row, row, jnp.linspace(0, 0, 1))
x0, y0, z0 = r0

# Matplotlib works on host-side data, so hand it NumPy grids and plain floats
# rather than JAX arrays.
grid_x = np.array(x[..., 0])
grid_y = np.array(y[..., 0])
centre_x, centre_y = float(x0), float(y0)
moment_x, moment_y = float(m[0]), float(m[1])

# Select the plane at index `res // 2` along the last spatial axis.
# NOTE(review): assumes the flattened grids reshape to (res, res, res[, 3])
# with the same axis order as `jnp.meshgrid(row, row)` — confirm against
# `configure`.
potential_slice = np.array(potential[idx].reshape(res, res, res)[..., res // 2])
field_slice = np.array(field[idx].reshape(res, res, res, 3)[..., res // 2, :])

# Half side length of the dashed outline.
# NOTE(review): hard-coded; presumably the prism half-width — confirm.
half_width = 1

with plt.style.context("science"):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

    # Subplot 1: contour plot of the potential.
    ax1.contourf(grid_x, grid_y, potential_slice, levels=50, cmap="viridis")
    ax1.set_title("Magnetic Scalar Potential")
    ax1.set_xlabel("$x$")
    ax1.set_ylabel("$y$")

    # Dashed red square centred on the source, plus an arrow along (m_x, m_y).
    square = patches.Rectangle(
        (centre_x - half_width, centre_y - half_width),
        2 * half_width,
        2 * half_width,
        linewidth=1,
        edgecolor="r",
        facecolor="none",
        linestyle="--",
    )
    ax1.add_patch(square)
    ax1.arrow(
        centre_x,
        centre_y,
        moment_x,
        moment_y,
        head_width=0.15,
        head_length=0.15,
        fc="r",
        ec="r",
    )

    # Subplot 2: streamplot of the in-plane field components.
    Hx = field_slice[..., 0]
    Hy = field_slice[..., 1]
    ax2.streamplot(grid_x, grid_y, Hx, Hy, color="k", density=1.5)
    # Raw string: "\m" is not a valid escape sequence in a normal literal.
    ax2.set_title(r"Magnetic Field $\mathbf{H}$")
    ax2.set_xlabel("$x$")
    ax2.set_ylabel("$y$")

    plt.tight_layout()
    plt.show()