# How to use Multiple Devices
## Solving Equilibrium

In [1]:
import sys
import os

# Put the notebook's own directory first on the import path, and add the
# directory three levels up as a fallback location.
# NOTE(review): presumably "../../../" is the repository root so that a local
# checkout of `desc` can be imported — confirm against the repo layout.
sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../../"))

In [2]:
from desc import _set_cpu_count, set_device

# Number of devices to use; later cells reuse this value when building the
# parallel force-balance objective.
num_device = 2

# Ask for `num_device` CPU devices and select the CPU backend with that count.
# NOTE(review): presumably this has to run before the remaining desc/JAX
# imports in the next cells so the backend is created with the requested
# device count — confirm against the DESC documentation.
_set_cpu_count(num_device)
set_device("cpu", num_device=num_device)

In [3]:
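# Optional: uncomment the lines below to configure JAX's persistent
# compilation cache (cache directory and the thresholds for caching entries).
# import jax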

# jax.config.update("jax_compilation_cache_dir", "./jax-caches")
# jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
# jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

In [4]:
import numpy as np

from desc import config as desc_config
from desc.examples import get
from desc.objectives import *
from desc.objectives.getters import *
from desc.grid import LinearGrid
from desc.backend import jnp
from desc.plotting import plot_grid
from desc.backend import jax
from desc.optimize import Optimizer

DESC version 0.13.0+1543.g3edf125e0.dirty,using JAX backend, jax version=0.5.0, jaxlib version=0.5.0, dtype=float64
Using 2 CPUs:
	 CPU 0: TFRT_CPU_0 with 7.25 GB available memory
	 CPU 1: TFRT_CPU_1 with 7.25 GB available memory


In [5]:
# Load the HELIOTRON example equilibrium and set a small spectral / grid
# resolution for this demonstration.
eq = get("HELIOTRON")
eq.change_resolution(L=3, M=3, N=3, L_grid=6, M_grid=6, N_grid=6)



In [6]:
# Build the force-balance objective split across `num_device` devices, plus the
# fixed-boundary constraints used for the equilibrium solve.
obj = get_parallel_forcebalance(eq, num_device=num_device)
cons = get_fixed_boundary_constraints(eq)

# Show which device each sub-objective has been assigned to.
available_devices = jax.devices(desc_config["kind"])
for sub_objective in obj.objectives:
    print(available_devices[sub_objective._device_id])

Precomputing transforms
Precomputing transforms


When using multiple devices, the ObjectiveFunction will run each 
sub-objective on the device specified in the sub-objective. 
Setting the deriv_mode to 'blocked' to ensure that each sub-objective
runs on the correct device.


TFRT_CPU_0
TFRT_CPU_1


In [7]:
# Run a single iteration (maxiter=1) with every stopping tolerance set to zero,
# printing the most detailed progress output requested here (verbose=3).
eq.solve(
    objective=obj,
    constraints=cons,
    maxiter=1,
    ftol=0,
    gtol=0,
    xtol=0,
    verbose=3,
)

Building objective: lcfs R
Building objective: lcfs Z
Building objective: fixed Psi
Building objective: fixed pressure
Building objective: fixed iota
Building objective: fixed sheet current
Building objective: self_consistency R
Building objective: self_consistency Z
Building objective: lambda gauge
Building objective: axis R self consistency
Building objective: axis Z self consistency
Timer: Objective build = 1.46 sec
Timer: Linear constraint projection build = 4.23 sec
Number of parameters: 76
Number of objectives: 2028
Timer: Initializing the optimization = 5.74 sec

Starting optimization
Using method: lsq-exact
DESC version 0.13.0+1543.g3edf125e0.dirty,using JAX backend, jax version=0.5.0, jaxlib version=0.5.0, dtype=float64
CPU Info:  13th Gen Intel(R) Core(TM) i5-1335U CPU with 6.56 GB available memory
This should run on device id:0
DESC version 0.13.0+1543.g3edf125e0.dirty,using JAX backend, jax version=0.5.0, jaxlib version=0.5.0, dtype=float64
CPU Info:  13th Gen Intel(R) Core

(Equilibrium at 0x7f58e4012a50 (L=3, M=3, N=3, NFP=19, sym=True, spectral_indexing=fringe),
     message: Maximum number of iterations has been exceeded.
     success: False
         fun: [-1.320e-04 -1.389e-04 ...  1.297e-01  9.594e-03]
           x: [-9.209e-03 -1.293e-01 ...  1.671e-02  1.916e-01]
         nit: 1
        cost: 12008336.70921457
           v: [ 1.000e+00  1.000e+00 ...  1.000e+00  1.000e+00]
  optimality: 4357.5087192217225
        nfev: 2
        njev: 2
        allx: [Array([-3.392e-05,  8.921e-06, ...,  0.000e+00,  0.000e+00],      dtype=float64), Array([ 2.764e-03,  2.509e-03, ...,  0.000e+00,  0.000e+00],      dtype=float64)]
       alltr: [Array( 4.615e+06, dtype=float64), np.float64(4614795.672796082)]
     history: [[{'R_lmn': Array([-3.392e-05,  8.921e-06, ...,  0.000e+00,  1.850e-05],      dtype=float64), 'Z_lmn': Array([ 9.011e-06,  1.167e-05, ..., -3.697e-05,  1.686e-05],      dtype=float64), 'L_lmn': Array([-6.194e-07, -1.567e-05, ..., -9.721e-06, -1.466

In [None]:
# Plot the grid stored in each sub-objective's transform constants.
for sub_objective in obj.objectives:
    sub_grid = sub_objective.constants["transforms"]["grid"]
    plot_grid(sub_grid)

## Using other Objectives
Above we used the convenience function for the force balance objective, but we can also use other objectives with this approach. There are some extra steps you need to take, though.

In [None]:
# Start again from a fresh copy of the HELIOTRON example at the same small
# resolution used above.
eq = get("HELIOTRON")
eq.change_resolution(L=3, M=3, N=3, L_grid=6, M_grid=6, N_grid=6)

In [None]:
# Two grids covering different flux surfaces: rho = 0.2, 0.4 for the first
# quasisymmetry objective and rho = 0.6, 0.8, 1.0 for the second.
grid1 = LinearGrid(
    M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=np.array([0.2, 0.4]), sym=True
)
grid2 = LinearGrid(
    M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=np.array([0.6, 0.8, 1.0]), sym=True
)

# `device_id` selects the device each objective should run on: the first
# quasisymmetry objective on device 0, the other two objectives on device 1.
obj1 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid1, device_id=0)
obj2 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid2, device_id=1)
obj3 = AspectRatio(eq=eq, target=8, weight=100, device_id=1)

objs = [obj1, obj2, obj3]
# Extra steps: build each objective, then move it to its assigned device.
for i, obji in enumerate(objs):
    obji.build(verbose=3)
    # jax.device_put returns a new object rather than moving `obji` in place.
    obji = jax.device_put(obji, jax.devices(desc_config["kind"])[obji._device_id])
    # Point the device-placed objective back at the original equilibrium.
    # NOTE(review): presumably device_put also copies the equilibrium held by
    # the objective, which is why it is reset here — confirm.
    obji.things[0] = eq
    # Store the device-placed objective back in the list; rebinding the loop
    # variable alone would leave `objs` holding the original objectives.
    objs[i] = obji

objective = ObjectiveFunction(objs)
objective.build(verbose=3)

In [None]:
# Boundary modes whose largest absolute mode number exceeds `k` are held fixed;
# for R the [0, 0, 0] mode is fixed as well.
k = 1
R_modes = np.vstack(
    [
        [0, 0, 0],
        eq.surface.R_basis.modes[np.abs(eq.surface.R_basis.modes).max(axis=1) > k, :],
    ]
)
Z_modes = eq.surface.Z_basis.modes[np.abs(eq.surface.Z_basis.modes).max(axis=1) > k, :]

# Constraints applied during the optimization: the selected boundary modes,
# the pressure profile and the total flux Psi.
constraints = (
    FixBoundaryR(eq=eq, modes=R_modes),
    FixBoundaryZ(eq=eq, modes=Z_modes),
    FixPressure(eq=eq),
    FixPsi(eq=eq),
)

# Same least-squares method the equilibrium solve above reported using.
optimizer = Optimizer("lsq-exact")

In [None]:
# Take a single optimization step with the multi-device objective.
eq.optimize(
    objective=objective,
    constraints=constraints,
    optimizer=optimizer,
    maxiter=1,
    verbose=3,
    options=dict(initial_trust_ratio=1.0),
)

## Optimization using Proximal Method

In [None]:
# Load the precise_QA example equilibrium. A higher-resolution alternative is
# kept commented out; the small resolution below is the one actually applied.
eq = get("precise_QA")
# eq.change_resolution(12, 12, 12, 24, 24, 24)
eq.change_resolution(L=3, M=3, N=3, L_grid=6, M_grid=6, N_grid=6)

In [None]:
# Two grids covering different flux surfaces: four surfaces between
# rho = 0.2 and 0.5, and six surfaces between rho = 0.6 and 1.0.
grid1 = LinearGrid(
    M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=jnp.linspace(0.2, 0.5, 4), sym=True
)
grid2 = LinearGrid(
    M=eq.M_grid, N=eq.N_grid, NFP=eq.NFP, rho=jnp.linspace(0.6, 1.0, 6), sym=True
)

# NOTE(review): all three objectives use device_id=0 here, unlike the previous
# section which split them across devices 0 and 1 — confirm this is intended
# for the proximal method.
obj1 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid1, device_id=0)
obj2 = QuasisymmetryTwoTerm(eq=eq, helicity=(1, eq.NFP), grid=grid2, device_id=0)
obj3 = AspectRatio(eq=eq, target=8, weight=100, device_id=0)

objs = [obj1, obj2, obj3]
# Build each objective, then move it to its assigned device.
for i, obji in enumerate(objs):
    obji.build(verbose=3)
    # jax.device_put returns a new object rather than moving `obji` in place.
    obji = jax.device_put(obji, jax.devices(desc_config["kind"])[obji._device_id])
    # Point the device-placed objective back at the original equilibrium.
    # NOTE(review): presumably device_put also copies the equilibrium held by
    # the objective, which is why it is reset here — confirm.
    obji.things[0] = eq
    # Store the device-placed objective back in the list; rebinding the loop
    # variable alone would leave `objs` holding the original objectives.
    objs[i] = obji

objective = ObjectiveFunction(objs)
objective.build(verbose=3)

In [None]:
# Boundary modes whose largest absolute mode number exceeds `k` are held fixed;
# for R the [0, 0, 0] mode is fixed as well.
k = 1
R_modes = np.vstack(
    [
        [0, 0, 0],
        eq.surface.R_basis.modes[np.abs(eq.surface.R_basis.modes).max(axis=1) > k, :],
    ]
)
Z_modes = eq.surface.Z_basis.modes[np.abs(eq.surface.Z_basis.modes).max(axis=1) > k, :]

# Compared with the previous section, force balance is now supplied as a
# constraint and the current profile is fixed as well.
constraints = (
    ForceBalance(eq=eq),
    FixBoundaryR(eq=eq, modes=R_modes),
    FixBoundaryZ(eq=eq, modes=Z_modes),
    FixPressure(eq=eq),
    FixPsi(eq=eq),
    FixCurrent(eq=eq),
)

# Proximal variant of the least-squares optimizer used earlier.
optimizer = Optimizer("proximal-lsq-exact")

In [None]:
# Take a single optimization step with the proximal optimizer.
eq.optimize(
    objective=objective,
    constraints=constraints,
    optimizer=optimizer,
    maxiter=1,
    verbose=3,
    options=dict(initial_trust_ratio=1.0),
)