# GPU Acceleration

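This notebook demonstrates GPU acceleration of `superscreen` models by timing the same calculations with the default NumPy (CPU) backend, with JAX on the CPU, and with JAX on a GPU (if one is available).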

In [None]:
# Notebook setup: inline plotting, environment, logging, imports, plot defaults.
%config InlineBackend.figure_formats = {"retina", "png"}
%matplotlib inline

import os

# Set before numpy is imported below. Presumably this limits OpenBLAS to a
# single thread so the CPU timings are not skewed by BLAS threading -- TODO confirm.
os.environ["OPENBLAS_NUM_THREADS"] = "1"

import logging

# Emit log records at INFO level and above.
logging.basicConfig(level=logging.INFO)

import jax
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Default figure size and font size for every plot in this notebook.
plt.rcParams["figure.figsize"] = (8, 6)
plt.rcParams["font.size"] = 14

import superscreen as sc
from superscreen.geometry import circle, box

## Superconducting ring

In [None]:
# Geometry and layer parameters for the ring device.
length_units = "um"
ro, ri = 3, 1  # outer radius, inner radius
slit_width = 0.25
layer = sc.Layer(
    "base",
    london_lambda=0.100,
    thickness=0.025,
    z0=0,
)

# Outlines of the film (outer circle) and of the hole (inner circle).
ring, hole = circle(ro), circle(ri)
# Region 20% larger in radius than the ring, enclosing the whole device.
bounding_box = sc.Polygon(
    "bounding_box",
    layer="base",
    points=circle(1.2 * ro),
)

In [None]:
# Assemble the ring device from the objects defined in the previous cell.
# The `layer` object created there is reused instead of constructing a second,
# identical sc.Layer, so the layer parameters live in exactly one place.
device = sc.Device(
    "ring",
    layers=[layer],
    films=[sc.Polygon("ring", layer="base", points=ring)],
    holes=[sc.Polygon("hole", layer="base", points=hole)],
    abstract_regions=[bounding_box],
    length_units=length_units,
)

In [None]:
# Build the mesh for the ring device.
device.make_mesh(
    min_points=4_000,
    optimesh_steps=10,
)

In [None]:
# Draw the device with a legend, leaving out the "bounding_box" region.
fig, ax = device.draw(
    legend=True,
    exclude="bounding_box",
)

In [None]:
# Sample positions shared by every cross-section.
xs = np.linspace(-3.5, 3.5, 401)

# Each entry is an (N, 2) array of [x, y] coordinates along one line cut.
cross_section_coords = [
    np.stack(xy_columns, axis=1)
    for xy_columns in (
        (xs, 0 * xs),  # horizontal cut along y = 0
        (xs, -2 * np.ones_like(xs)),  # horizontal cut along y = -2
        (0 * xs, xs),  # vertical cut along x = 0
    )
]

### Trapped flux

Here we solve for the field and current distribution from circulating currents associated with flux trapped in the hole.

We assume there is a total current of 1 mA circulating clockwise in the ring (associated with some positive net trapped flux), and that there is otherwise no applied magnetic field. From here we can calculate the current distribution in the ring, the total magnetic field in the plane of the ring, and the flux through the ring.

Note that, although here we are assuming no applied field, we can also solve models with both trapped flux and applied fields.

In [None]:
# 1 mA circulating around the hole named "hole"; no applied field is given.
circulating_currents = {"hole": "1 mA"}

# Keyword arguments shared by every sc.solve() call in this section.
kwargs = {
    "circulating_currents": circulating_currents,
    "field_units": "mT",
    "current_units": "mA",
}

#### NumPy (CPU)

In [None]:
%timeit sc.solve(device,**kwargs); clear_output(wait=True)

#### JAX (CPU)

In [None]:
with jax.default_device(jax.devices("cpu")[0]):
    %timeit sc.solve(device, gpu=True, **kwargs); clear_output(wait=True)

#### JAX (GPU if available)

In [None]:
if "cpu" in jax.devices()[0].device_kind:
    print("Skipping because there is no GPU available.")
else:
    %timeit sc.solve(device, gpu=True, **kwargs); clear_output(wait=True)

In [None]:
# Solve once more without the `gpu` flag and keep the last returned solution.
solution = sc.solve(device, **kwargs)[-1]

# Plot the fields together with the first (y = 0) cross-section.
fig, axes = solution.plot_fields(
    figsize=(6, 8),
    cross_section_coords=cross_section_coords[:1],
)

### Solve for a specific fluxoid state: $\Phi^f=n\Phi_0$

Current and field distributions for a given fluxoid state $\Phi^f=n\Phi_0$, where $\Phi_0$ is the superconducting flux quantum, can be modeled by adjusting the circulating current $I_\mathrm{circ}$ to realize the desired fluxoid value. This calculation is performed by the function `superscreen.find_fluxoid_solution()`.

Here we solve for the current distribution in the ring for the $n=0$ fluxoid state (i.e. Meissner state), which can be achieved by cooling the ring through its superconducting transition with no applied field. If a small field is then applied, it is screened by the ring such that the fluxoid remains zero.

In [None]:
# Keyword arguments shared by every find_fluxoid_solution() call in this
# section: zero fluxoid for "hole" and a constant applied field of 1 (in mT).
kwargs = {
    "fluxoids": {"hole": 0},
    "applied_field": sc.sources.ConstantField(1),
    "field_units": "mT",
    "current_units": "mA",
}

#### NumPy (CPU)

In [None]:
# Baseline timing: find_fluxoid_solution() without the `gpu` flag.
# clear_output() is part of the timed statement.
%timeit sc.find_fluxoid_solution(device, **kwargs); clear_output(wait=True)

#### JAX (CPU)

In [None]:
with jax.default_device(jax.devices("cpu")[0]):
    %timeit sc.find_fluxoid_solution(device, gpu=True, **kwargs); clear_output(wait=True)

#### JAX (GPU if available)

In [None]:
if "cpu" in jax.devices()[0].device_kind:
    print("Skipping because there is no GPU available.")
else:
    %timeit sc.find_fluxoid_solution(device, gpu=True, **kwargs); clear_output(wait=True)

In [None]:
# Zero-fluxoid (n = 0) state of the ring under a 1 mT applied field.
solution, result = sc.find_fluxoid_solution(device, **kwargs)

# Circulating current found for the hole, and the summed fluxoid in units of Phi_0.
I_circ = solution.circulating_currents["hole"]
fluxoid = sum(solution.hole_fluxoid("hole")).to("Phi_0").magnitude

print("Root finding result:\n", result)
print(
    f"Total circulating current: {I_circ:.3f} mA.",
    f"Total fluxoid: {fluxoid:.6f} Phi_0.",
    sep="\n",
)

In [None]:
# Fields for the n = 0 solution, with the y = 0 cross-section.
fig, axes = solution.plot_fields(
    figsize=(6, 8),
    cross_section_coords=cross_section_coords[:1],
)

In [None]:
# Currents for the n = 0 solution, with the y = 0 cross-section.
fig, axes = solution.plot_currents(
    figsize=(6, 8),
    cross_section_coords=cross_section_coords[:1],
)

## Film with multiple holes

Here we simulate a device with fewer symmetries than the ring, namely a rectangular film with two off-center rectangular holes.

In [None]:
length_units = "um"

# One layer named "base".
layers = [sc.Layer("base", Lambda=0.1, z0=0)]

# One rectangular film built from box(8, 4).
films = [sc.Polygon("film", layer="base", points=box(8, 4))]

# Two rectangular holes, both offset from the film's center.
holes = [
    sc.Polygon(
        "hole0",
        layer="base",
        points=box(5, 1, center=(0.5, -0.25)),
    ).resample(101),
    sc.Polygon(
        "hole1",
        layer="base",
        points=box(1, 2.5, center=(-3, 0.25)),
    ).resample(51),
]

# Slightly larger box (9 x 5) surrounding the film.
abstract_regions = [sc.Polygon("bounding_box", layer="base", points=box(9, 5))]

# Note: this rebinds `device`, replacing the ring device defined earlier.
device = sc.Device(
    "rect",
    length_units=length_units,
    layers=layers,
    films=films,
    holes=holes,
    abstract_regions=abstract_regions,
)

In [None]:
# Draw the rectangular device, leaving out the "bounding_box" region.
fig, ax = device.draw(exclude="bounding_box")

In [None]:
# Build the mesh for the rectangular device (optimesh_steps is None here).
device.make_mesh(
    min_points=4_000,
    optimesh_steps=None,
)

In [None]:
# Plot the device with its mesh and report the mesh size in the title.
fig, ax = device.plot(mesh=True)
_ = ax.set_title(
    "Mesh: {} points, {} triangles".format(
        device.points.shape[0], device.triangles.shape[0]
    )
)

### Full mutual inductance matrix

#### NumPy (CPU)

In [None]:
# Baseline timing: mutual inductance matrix without the `gpu` flag.
# clear_output() is part of the timed statement.
%timeit device.mutual_inductance_matrix(units="pH"); clear_output(wait=True)

#### JAX (CPU)

In [None]:
with jax.default_device(jax.devices("cpu")[0]):
    %timeit device.mutual_inductance_matrix(units="pH", gpu=True); clear_output(wait=True)

#### JAX (GPU if available)

In [None]:
if "cpu" in jax.devices()[0].device_kind:
    print("Skipping because there is no GPU available.")
else:
    %timeit device.mutual_inductance_matrix(units="pH", gpu=True); clear_output(wait=True)

In [None]:
# Compute the mutual inductance matrix in pH, then show its shape and contents.
M = device.mutual_inductance_matrix(units="pH")
print("Mutual inductance matrix shape:", M.shape)
display(M)

As expected, the mutual inductance matrix is approximately symmetric:

In [None]:
# Fractional difference between the two off-diagonal elements M[0, 1] and
# M[1, 0], taken relative to the smaller of the two values.
asymmetry = float(np.abs((M[0, 1] - M[1, 0]) / min(M[0, 1], M[1, 0])))
# Reported as a percentage.
print(f"Mutual inductance matrix fractional asymmetry: {100 * asymmetry:.3f}%")

### Model both holes in the $n=0$ fluxoid state

In [None]:
# Keyword arguments shared by the timing cells below: zero fluxoid for both
# holes and a constant applied field of 1 (in mT).
kwargs = {
    "fluxoids": {"hole0": 0, "hole1": 0},
    "applied_field": sc.sources.ConstantField(1),
    "field_units": "mT",
    "current_units": "mA",
}

#### NumPy (CPU)

In [None]:
# Baseline timing without the `gpu` flag:
# n = 0 fluxoid state, apply a field of 1 mT
# clear_output() is part of the timed statement.
%timeit sc.find_fluxoid_solution(device, **kwargs); clear_output(wait=True)

#### JAX (CPU)

In [None]:
%%timeit
with jax.default_device(jax.devices("cpu")[0]):
    # n = 0 fluxoid state, apply a field of 1 mT
    %timeit sc.find_fluxoid_solution(device, gpu=True, **kwargs); clear_output(wait=True)

#### JAX (GPU if available)

In [None]:
if "cpu" in jax.devices()[0].device_kind:
    print("Skipping because there is no GPU available.")
else:
    %timeit sc.find_fluxoid_solution(device, gpu=True, **kwargs); clear_output(wait=True)

In [None]:
# Both holes in the n = 0 fluxoid state, with a 1 mT applied field.
solution, result = sc.find_fluxoid_solution(
    device,
    field_units="mT",
    current_units="mA",
    applied_field=sc.sources.ConstantField(1),
    fluxoids={"hole0": 0, "hole1": 0},
)
# Discard whatever the solver wrote to the cell output.
clear_output(wait=True)

In [None]:
# Circulating currents found by the solver, and the summed fluxoid of each
# hole expressed in units of Phi_0.
I_circ = solution.circulating_currents
fluxoids = [
    sum(solution.hole_fluxoid(hole_name)).to("Phi_0").magnitude
    for hole_name in ["hole0", "hole1"]
]

print("Least-squares minimization result:\n", result)
print(
    f"Total circulating current: {I_circ} mA.",
    f"Total fluxoid: {fluxoids} Phi_0.",
    sep="\n",
)

In [None]:
# Plot the fields, then the currents, for the two-hole solution. `fig` and
# `axes` are rebound by the second call, so they end up referring to the
# currents figure only.
fig, axes = solution.plot_fields(figsize=(8, 3))
fig, axes = solution.plot_currents(figsize=(8, 3))

In [None]:
# Presumably reports the versions of superscreen and its dependencies used to
# run this notebook -- verify against the superscreen documentation.
sc.version_table()