In [1]:
%reload_ext autoreload
%autoreload 2

import MaNTA
import sys
import toml

In [2]:
import os
os.environ.pop("LD_LIBRARY_PATH", None)

# from JAXTransportSystem import JAXTransportSystem
# from JAXAdjointProblem import JAXAdjointProblem
import jax.numpy as jnp
import jax

from functools import partial

from yancc_wrapper import yancc_wrapper 

from typing import NamedTuple

class StellaratorParams(NamedTuple):
    SourceCenter: float
    SourceHeight: float
    SourceWidth: float
    EdgeTemperature: float
    EdgeDensity: float
    n0: float

    @classmethod
    def from_config(cls, config: MaNTA.TomlValue):
        return cls(
            SourceCenter=config["SourceCenter"],
            SourceHeight=config["SourceHeight"],
            SourceWidth=config["SourceWidth"],
            EdgeTemperature=config["EdgeTemperature"],
            EdgeDensity=config["EdgeDensity"],
            n0=config["n0"]
        )
"""
class StellaratorTransport

Computes sources and neoclassical fluxes (returned from yancc) as required by MaNTA
"""
class StellaratorTransport(MaNTA.TransportSystem): 
    def __init__(self, config , grid: MaNTA.Grid = None):
        MaNTA.TransportSystem.__init__(self)
        self.nVars = 1

        ### Remember to set boundary conditions ####
        self.isUpperDirichlet  = True
        self.isLowerDirichlet  = False

        self.params = StellaratorParams.from_config(config)
        self.yancc_wrapper = yancc_wrapper(self.Density, 1e20, 1e3)
        self.dSigmaFn_dVars = jax.grad(self.SigmaFn, argnums=1)

    def LowerBoundary(self, index, t):
        return 0.0

    def UpperBoundary(self, index, t):
        return 1.5 * self.params.EdgeTemperature * self.Density(0.9)

    #@partial(jax.jit, static_argnums=(0,1))
    def SigmaFn( self, index, state, x, t ):
        f = self.yancc_wrapper.flux(state, x)
        return -f

    @partial(jax.jit, static_argnums=(0,1))
    def Sources( self, index, state, x, t ):
        return self.source(index, state, x, t, self.params)

    
    """
    Sigma and source, and auxilliary functions to be overloaded in derived classes

    Parameters
    ----------
    index : int
        Variable index
    state : dict
        Dictionary containing "Variable", "Derivative, "Flux", "Aux", and "Scalar" arrays
    x : float
        Spatial location
    t : float
        Time
    params : NamedTuple
        Transport system parameters, passed for JAX PyTree compatibility
    Returns
    -------
    float
        Computed sigma or source term
    """
    def sigma( self, index, state, x, t, params: NamedTuple ):
        pass

    def source( self, index, state, x, t, params: NamedTuple ):
        return params.SourceHeight * jnp.exp(-(x - params.SourceCenter)**2 / (2 * params.SourceWidth**2))

    #@partial(jax.jit, static_argnums=(0,1))
    def dSigmaFn_dq( self, index, state, x, t):
        return self.dSigmaFn_dVars(index,state,x,t)["Derivative"]
    
    #@partial(jax.jit, static_argnums=(0,1))
    def dSigmaFn_du( self, index, state, x, t):
        return self.dSigmaFn_dVars(index,state,x,t)["Variable"]
    
    def dSigma_dPhi( self, index, state, x, t):
        return self.dSigmaFn_dVars(index,state,x,t)["Aux"]
    
    @partial(jax.jit, static_argnums=(0,1))
    def dSources_du( self, index, state, x, t ):
        return jax.grad(self.Sources, argnums=1)(index, state, x, t)["Variable"]

    @partial(jax.jit, static_argnums=(0,1))
    def dSources_dq( self, index, state, x, t ):
        return jax.grad(self.Sources, argnums=1)(index, state, x, t)["Derivative"]

    @partial(jax.jit, static_argnums=(0,1))
    def dSources_dsigma( self, index, state, x, t ):
        return jax.grad(self.Sources, argnums=1)(index, state, x, t)["Flux"]
    
    @partial(jax.jit, static_argnums=(0,1))
    def dSources_dPhi( self, index, state, x, t ):
        return jax.grad(self.Sources, argnums=1)(index, state, x, t)["Aux"]
    
    @partial(jax.jit, static_argnums=(0,1))
    def InitialValue( self, index, x ):
        return 1.5 * self.params.EdgeTemperature * self.Density(x)
    
    @partial(jax.jit, static_argnums=(0,1))
    def InitialDerivative( self, index, x ):
        return jax.grad(self.InitialValue, argnums=1)(index, x)
    
    def InitialAuxValue(self, index, x):
        return 0.0

    def Density(self, x):
        return (self.params.n0 - self.params.EdgeDensity) * (1 - x*x) + self.params.EdgeDensity
    
    """
    Create the adjoint problem associated with this transport system
    
    Returns
    -------
    JAXAdjointProblem
        The adjoint problem object
    """
    def createAdjointProblem(self):
        pass


In [3]:

st_config = {
    "SourceCenter": 0.1,
    "SourceHeight": 5.0,
    "SourceWidth": 0.1,
    "EdgeTemperature":0.5,
    "EdgeDensity": 0.1,
    "n0": 0.5,
}

st = StellaratorTransport(st_config)
runner = MaNTA.Runner(st)

Initializing yancc wrapper with parameters:
  nx=5, na=65, nt=17, nz=33
yancc wrapper initialized successfully.


yancc wrapper initialized successfully.


In [4]:
config = {
    "OutputFilename": "stellarator",
    "Polynomial_degree": 4,
    "Grid_size": 3,
    "Lower_boundary": 0.0,
    "Upper_boundary": 0.9,
    "Relative_tolerance": 0.01,
    "tFinal": 2.0,
    "delta_t": 0.5,
    "restart": True,
}

runner.configure(config)

INFO: Using default value for configuration option RestartFile
INFO: Using default value for configuration option solveAdjoint
Total HDG degrees of freedom 49
INFO: Using default value for configuration option Absolute_tolerance
INFO: Using default value for configuration option tau
INFO: Using default value for configuration option tZero
INFO: Using default value for configuration option MinStepSize
INFO: Using default value for configuration option OutputPoints


In [5]:
runner.run()

Setting initial conditions
Evaluating residual
  current residual norm: 1.69591e-131
Evaluating residual
  current residual norm: 0
Updating Jacobian at time 0
Evaluating residual
  current residual norm: 0
Evaluating residual
  current residual norm: 0.00767864
Evaluating residual
  current residual norm: 0.0039321
Evaluating residual
  current residual norm: 0.0039321
Updating Jacobian at time 0


Number of Residual Evaluations due to IDACalcIC 5


Evaluating residual
  current residual norm: 0.00306143
Evaluating residual
  current residual norm: -nan
Updating Jacobian at time 0.0005
Evaluating residual
  current residual norm: 0.0713247
Evaluating residual
  current residual norm: 0.00821832
Evaluating residual
  current residual norm: 0.00350945
Evaluating residual
  current residual norm: 0.013003
E0220 17:00:51.374420  716242 pjrt_stream_executor_client.cc:2916] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory
Evaluating residual
  current residual norm: 0.000654194


Caught exception : XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory

At:
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py(1306): __call__
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/profiler.py(354): wrapper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(1877): _pjit_call_impl_python
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(179): _python_pjit_helper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(292): cache_miss
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/traceback_util.py(186): reraise_with_filtered_traceback
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/dispatch.py(93): apply_primitive
  /global/homes/e/eatocc

E0220 17:00:51.884590  716242 pjrt_stream_executor_client.cc:2916] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory
Evaluating residual
  current residual norm: 1.65456e-14


Caught exception : XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory

At:
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py(1306): __call__
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/profiler.py(354): wrapper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(1877): _pjit_call_impl_python
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(179): _python_pjit_helper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(292): cache_miss
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/traceback_util.py(186): reraise_with_filtered_traceback
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/dispatch.py(93): apply_primitive
  /global/homes/e/eatocc

E0220 17:00:52.296349  716242 pjrt_stream_executor_client.cc:2916] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory
Evaluating residual
  current residual norm: 5.57616e-15


Caught exception : XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory

At:
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py(1306): __call__
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/profiler.py(354): wrapper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(1877): _pjit_call_impl_python
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(179): _python_pjit_helper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(292): cache_miss
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/traceback_util.py(186): reraise_with_filtered_traceback
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/dispatch.py(93): apply_primitive
  /global/homes/e/eatocc

E0220 17:00:52.701075  716242 pjrt_stream_executor_client.cc:2916] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory
Evaluating residual
  current residual norm: 3.15788e-15


Caught exception : XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory

At:
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py(1306): __call__
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/profiler.py(354): wrapper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(1877): _pjit_call_impl_python
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(179): _python_pjit_helper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(292): cache_miss
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/traceback_util.py(186): reraise_with_filtered_traceback
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/dispatch.py(93): apply_primitive
  /global/homes/e/eatocc

E0220 17:00:53.217040  716242 pjrt_stream_executor_client.cc:2916] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory
Evaluating residual
  current residual norm: 3.78497e-15


Caught exception : XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory

At:
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py(1306): __call__
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/profiler.py(354): wrapper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(1877): _pjit_call_impl_python
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(179): _python_pjit_helper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(292): cache_miss
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/traceback_util.py(186): reraise_with_filtered_traceback
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/dispatch.py(93): apply_primitive
  /global/homes/e/eatocc

E0220 17:00:53.695220  716242 pjrt_stream_executor_client.cc:2916] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory
Evaluating residual
  current residual norm: 1.15397e-15


Caught exception : XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory

At:
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py(1306): __call__
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/profiler.py(354): wrapper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(1877): _pjit_call_impl_python
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(179): _python_pjit_helper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(292): cache_miss
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/traceback_util.py(186): reraise_with_filtered_traceback
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/dispatch.py(93): apply_primitive
  /global/homes/e/eatocc

E0220 17:00:54.099834  716242 pjrt_stream_executor_client.cc:2916] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory
Evaluating residual
  current residual norm: 1.57894e-15


Caught exception : XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory

At:
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py(1306): __call__
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/profiler.py(354): wrapper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(1877): _pjit_call_impl_python
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(179): _python_pjit_helper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(292): cache_miss
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/traceback_util.py(186): reraise_with_filtered_traceback
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/dispatch.py(93): apply_primitive
  /global/homes/e/eatocc

E0220 17:00:54.553079  716242 pjrt_stream_executor_client.cc:2916] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory
Evaluating residual
  current residual norm: 2.30794e-15


Caught exception : XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory

At:
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py(1306): __call__
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/profiler.py(354): wrapper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(1877): _pjit_call_impl_python
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(179): _python_pjit_helper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(292): cache_miss
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/traceback_util.py(186): reraise_with_filtered_traceback
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/dispatch.py(93): apply_primitive
  /global/homes/e/eatocc

E0220 17:00:54.982985  716242 pjrt_stream_executor_client.cc:2916] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory
[ERROR][rank 0][/global/homes/i/iabel/sundials/git/src/ida/ida.c:2385][IDAHandleFailure] At t = 0.001 repeated recoverable residual errors.

SUNDIALS_ERROR: IDASolve() failed with retval = -9



Caught exception : XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to get module function: CUDA_ERROR_OUT_OF_MEMORY: out of memory

At:
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py(1306): __call__
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/profiler.py(354): wrapper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(1877): _pjit_call_impl_python
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(179): _python_pjit_helper
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/pjit.py(292): cache_miss
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/traceback_util.py(186): reraise_with_filtered_traceback
  /global/homes/e/eatocco/.conda/envs/desc-env/lib/python3.12/site-packages/jax/_src/dispatch.py(93): apply_primitive
  /global/homes/e/eatocc

RuntimeError: IDASolve could not complete