### Isothermal JMAK Function 
$$
X(t,T) = 1 - \exp{-b(T)^n(t - t_{inc}(T))^n}
t_{inc}= A_1 \exp{B_1/T} \qquad b(T)  = A_2 \exp{B_2/T} \\
a_1 = \ln{A_1},a_2 = \ln{A_2}
$$

### Nonisothermal JMAK Function
$$
Y(t)  = \begin{cases}
0 \qquad & t \leq t_0 \\ 
Y^{\dagger}(t - t_0) \qquad & t > t_0 \\ 
\end{cases}
$$

$$
\begin{matrix}
t_0 = \rho^{-1}(1) \qquad \rho = \int \frac{d \tau}{t_{inc}(T(\tau))}\\ 
Y^{\dagger}(t) = 1 - \exp{\left\lbrace - \left( \int b(T(\tau)) d\tau \right)^n  \right\rbrace } 
\end{matrix}
$$

### Inverted Isothermal JMAK function 
$$
t_{iso}(X = X^*,T) = t_{inc}(T) + \frac{1}{b(T)}\left(-\ln{(1 - X^*)}\right)^{1/n} 
$$ 

Clearly for every $X^*$ and specified operating time $\tau_{op}$ we can find a $T_{iso}$ such that $t_{iso}(X = X^*,T = T_{iso}) = \tau_{op}$

We want to examine how transient events can impact the computed $T_{iso}$. 

## Recrystallization

### Inverted Non-isothermal JMAK Function
$$
t_{a}(X = X^*,T) = \tau_{a,inc} + \tau_{a,rx} \\ 
$$

where $\tau_{a,inc}$ solves: 

$$
\int_{0}^{\tau_{a,inc}} \frac{dt}{t_{inc}(T(t))} = 1
$$

and $\tau_{a,rx}$ solves 
$$
\left(-\ln{(1 - X^*)}\right)^{1/n} = \int_{0}^{\tau_{a,rx}} b(T(t)) dt
$$


Let $\mathcal{T}_{inc} = \{t: 0 \leq t \leq \tau_{a,inc}\}$ and similarly $\mathcal{T}_{rx} = \{t: 0 \leq t \leq \tau_{a,rx}\}$. Assume that each of these sets contains $N_{d,inc}$ and $N_{d,rx}$ sets each with 

$D_i = {s_i \leq t \leq t + \tau_r}, i = 1,...N_{d,\cdot}$ and further that $s_{i} \geq s_{i-1} + \tau_r \forall i = 2,... N_{d,\cdot}$ and $0 < s_{i} < \tau_{op} - \tau_r \forall i$. define $\mathcal{T}^d_{inc} = \cup_{i = 1}^{N_{d,inc}} D_i$ and a similar defintion for $\mathcal{T}_{rx}^d$ for the $D_i$ during recrystallization. The complements are $\mathcal{T}_{inc}^s = \mathcal{T}_{inc} \backslash \mathcal{T}_{inc}^d$ and $\mathcal{T}_{rx}^s = \mathcal{T}_{rx} \backslash \mathcal{T}_{rx}^d$

On each $D_i$ the temperature profile may be written as $T_r(t;T_a) = T_{a} + \Delta T_r(t)$ and assume that the profile is $C^1(\mathcal{T}_{inc})$ and $C^1(\mathcal{T}_{rx})$. on $\mathcal{T}_{inc}^s$ and $\mathcal{T}_{rx}^s$ the temperature is constant at $T_{a}$. The _goal_ is to solve for $T_a$.

Notice that the ordering of the sets, and their spacing does not impact the overall integral evalutions in the non-isothermal jmak equations. In fact, we may write: 

$$
\int_{0}^{\tau_{a,rx}} b(T(t)) dt = \int_{0}^{\tau_{a,rx} - N_{d,rx} \tau_{a,rx}} b(T_a) dt + N_d \int_{0}^{\tau_r} b(T_r(t;T_a)) dt = \left(\tau_{a,rx} - N_{d,rx} \tau_{a,rx} \right) b(T_a) + N_d \int_{0}^{\tau_r} b(T_r(t;T_a)) dt \\ 
$$

and 

$$
\int_{0}^{\tau_{a,inc}} \frac{dt}{t_{inc}(T(t))} = \frac{\tau_{a,inc} - N_{d,inc}\tau_r}{t_{inc}(T_a)} + N_{d,inc} \tau_r \int_{0}^{\tau_r} \frac{1}{t_{inc}(T_r(t; T_a))} dt
$$

with the important implication being that we only have to evaluate the integrals for a single disturbance temperature profile, not over the whole set of time. 

$$
\begin{matrix}
\tau_{op} = \tau_{a,inc} + \tau_{a,rx}  \qquad \text{equivalency to the isothermal case} \\ 
N_d = N_{d,inc} + N_{d,rx} \qquad \text{total number of cycles} \\ 
\frac{N_{d,inc}}{N_{d,rx}} = \alpha \lceil \frac{\tau_{a,inc}}{\tau_{a,rx}} \rceil \qquad \text{approximately evenly distributed cycles.}
\end{matrix}
$$
The intent is to specify the temperature  profile as mostly constant with some number $N_{d}$ of transient/disturbance periods and see how $T_a$ is effected relative to $T_{iso}$ as $N_{d}$ increases. 


The parameters  $a_1,B_2,a_2,B_2,n$, are estiamted from data and specified.

The parameters $X^*,\tau_{op}$ are known and specified

I approximate and specify $\Delta T_r(t), \tau_r$, 

I want to vary $\tau_{op}$ and $N_d$ and study the effect they have. $\alpha$ is sort-of a nuisance parameter because the effect of the disturbance temperatures may be somewhat different for the incubation and recrystallization time periods. 

In [196]:
import jax
import jax.numpy as jnp
from jax import jit, grad,vmap
from jaxopt import LBFGS, Bisection
from dataclasses import dataclass
from typing import Tuple
import optax
import optax.tree_utils as otu
from typing import Callable
import warnings 

__all__ = [
    "JMAKParams",
    "invert_isothermal_T",
    "solve_Ta_and_tau_inc",
]


@dataclass
class JMAKParams:
    """Physical and operational parameters for the non‑isothermal JMAK solver."""

    # Kinetic parameters (Arrhenius form)
    n: float   # Avrami exponent
    a1: float  # ln A1 for growth
    B1: float  # activation term for growth (K)
    a2: float  # ln A2 for  incubation
    B2: float  # activation term for incubation (K)

    # Target & schedule
    X_star: float  # required transformed fraction (0 < X_star < 1)
    tau_op: float  # total operation time (s)

    # Disturbance specification
    tau_r: float         # duration of a single disturbance window (s)
    N_d: int
    alpha: float
    deltaT_r: jnp.ndarray  # temperature *offset* profile for one window (shape (Nt,))
    t: jnp.ndarray         # time profile for one window (shape (Nt,))

    # Numerical safety bounds
    Tmin: float = 600.0   # K
    Tmax: float = 3000.0  # K

    # ---------- kinetic helpers ----------
    def t_inc(self, T: jnp.ndarray) -> jnp.ndarray:
        """Incubation time (s)."""
        return jnp.exp(self.a2 + self.B2 / T)

    def b(self, T: jnp.ndarray) -> jnp.ndarray:
        """Growth‑rate coefficient (1/s^n)."""
        return jnp.exp(self.a1 + self.B1 / T)


# -----------------------------------------------------------
# One‑window integrals (re‑evaluated each time Ta changes)
# -----------------------------------------------------------

def _single_window_integrals(p: JMAKParams):
    """Return functions I_inc_r(Ta) and I_rx_r(Ta) for a single disturbance window."""

    @jit
    def I_inc_r(Ta):
        T_prof = Ta + p.deltaT_r  # (Nt,)
        return jnp.trapezoid(1.0 / p.t_inc(T_prof), x = p.t)

    @jit
    def I_rx_r(Ta):
        T_prof = Ta + p.deltaT_r
        return jnp.trapezoid(p.b(T_prof), x = p.t)

    return I_inc_r, I_rx_r


# -----------------------------------------------------------
# Objective (least‑squares of the two residual equations)
# -----------------------------------------------------------

def _make_objective(p: JMAKParams):
    I_inc_r, I_rx_r = _single_window_integrals(p)
    ln_term = (-jnp.log1p(-p.X_star)) ** (1.0 / p.n)  # constant RHS of F2

    @jit
    def residuals(x):
        
        """Vector of residuals [F1, F2] for decision variables x = [Ta, tau_inc]."""
        Ta, tau_inc,N_d_inc = x  # direct indexing – JAX‑friendly
        # Box constraints as quadratic penalties
        penalty = 0.0
        penalty += jnp.where((Ta < p.Tmin) | (Ta > p.Tmax), 1e6 * (Ta - p.Tmin) ** 2, 0.0)
        penalty += jnp.where((tau_inc <= 0.0) | (tau_inc >= p.tau_op), 1e6 * tau_inc**2, 0.0)

        tau_rx = p.tau_op - tau_inc
        N_d_rx = p.N_d - N_d_inc

        F1 = (
            (tau_inc - N_d_inc * p.tau_r) / p.t_inc(Ta)
            + N_d_inc * I_inc_r(Ta)
            - 1.0
        )

        F2 = (
            (tau_rx - N_d_rx * p.tau_r) * p.b(Ta)
            + N_d_rx * I_rx_r(Ta)
            - ln_term
        )

        F3 = (
            N_d_inc*tau_rx- p.alpha*tau_inc*N_d_rx
        )

        return jnp.array([F1, F2,F3]) + penalty

    @jit
    def objective(x):
        r = residuals(x)
        return 0.5 * jnp.sum(r ** 2)

    return objective


# -----------------------------------------------------------
# Public solver (Optax‑based)
# -----------------------------------------------------------

def run_opt(init_params: jnp.ndarray, 
            fun: Callable,
             opt, 
             max_iter: int, 
             tol: float):

    value_and_grad_fun = optax.value_and_grad_from_state(fun)

    def step(carry):
        params, state = carry
        value, grad = value_and_grad_fun(params, state=state)
        updates, state = opt.update(
            grad, state, params, value=value, grad=grad, value_fn=fun,
            value_and_grad_fun = value_and_grad_fun
        )
        params = optax.apply_updates(params, updates)
        return params, state

    def continuing_criterion(carry):
        _, state = carry
        iter_num = otu.tree_get(state, 'count')
        grad = otu.tree_get(state, 'grad')
        err = otu.tree_l2_norm(grad)
        return (iter_num == 0) | ((iter_num < max_iter) & (err >= tol))

    init_carry = (init_params, opt.init(init_params))
    final_params, final_state = jax.lax.while_loop(
        continuing_criterion, step, init_carry
    )
    return final_params, final_state

def solve_Ta_and_tau_inc(
    jparams: JMAKParams,
    x0: Tuple[float, float] = (1000.0, 1e4,10),
    maxiter: int = 5000,
    tol: float = 1e-9,
    learning_rate: float = 1e-1,
    grad_tol: float = 1e-6,
):
    """Solve for baseline temperature *T_a* and incubation time *tau_inc* using Optax.

    Parameters
    ----------
    params : JMAKParams
        All kinetic, target and disturbance data.
    x0 : (Ta0, tau_inc_ratio)
        Initial guesses: temperature in K and fraction of *tau_op* to allocate to incubation.
    maxiter : int
        Maximum optimisation steps.
    tol : float
        Tolerance on the objective value.
    learning_rate : float
        Step size for Adam.
    grad_tol : float
        Stop if L2‑norm of gradient falls below this value.
    """

    Ta0, tau_inc0,N_d_inc = map(float, x0)

    objective = _make_objective(jparams)

    params,final_state = run_opt(jnp.array([Ta0, tau_inc0,N_d_inc]), 
                       objective, 
                        optax.lbfgs(),
                         maxiter, tol)

    Ta_opt, tau_inc_opt,N_d_inc = params
    tau_rx_opt = jparams.tau_op - tau_inc_opt
    N_d_rx = jparams.N_d - N_d_inc  
    # Check for convergence
    if otu.tree_get(final_state, "count") >= maxiter:
        warnings.warn("LBFGS did not converge within maxiter steps.")

    error = objective(params)
    if error > tol:
        warnings.warn("Error: {error} LBGFS did not converge to the specified tolerance.")
    
    return float(Ta_opt), float(tau_inc_opt),float(tau_rx_opt),float(N_d_inc), float(N_d_rx), float(error)



# -----------------------------------------------------------
# Optional: invert isothermal equation for reference temperature
# -----------------------------------------------------------

def invert_isothermal_T(
    params: JMAKParams,
    T_low: float = 200.0,
    T_high: float = 30000.0,
    tol: float = 1e-6,
):
    """Return the isothermal temperature that gives ``X_star`` in ``tau_op`` seconds."""

    ln_term = (-jnp.log1p(-params.X_star)) ** (1.0 / params.n)

    def g(T):
        ff = params.t_inc(T) + ln_term / params.b(T) - params.tau_op
        return ff

    root_solver = Bisection(g, T_low, T_high, tol=tol)
    return float(root_solver.run().params)


### Isothermal Verification

In [199]:
import sys
wlpath = 'E:/ORNL Collaboration/System Design/ASME Code/modeling_tungsten'
sys.path.append(
    wlpath
)


jmak_model = read_jmak_model_inference(
    rxp.joinpath(f'JMAK_Lopez et al. (2015) - MR_trunc_normal_params.csv')
)
a1,B1,a2,B2,n = jmak_model.a1,jmak_model.B1,jmak_model.a2,jmak_model.B2,jmak_model.n
YEAR_TO_SECONDS = 24. * 365.*3600.
N_d = 0
tau_r = 10.
p = JMAKParams(
    n=n,
    a1=a1,
    B1=B1,
    a2=a2,
    B2=B2,
    X_star=0.1,
    tau_op=YEAR_TO_SECONDS,
    tau_r=tau_r,
    N_d = N_d,
    alpha = 1.0,
    deltaT_r=None,
    t = None,
)

T_iso = invert_isothermal_T(p)
t = jnp.linspace(0, tau_r, 1000)

print("isothermal temperature:", T_iso - 273.15)
p.deltaT_r = jnp.ones(t.shape[0])*0.
p.t = t
Ta, tau_inc,tau_rx,N_d_inc,N_d_rx,error = solve_Ta_and_tau_inc(p, 
                                                    x0=(T_iso,0.6*YEAR_TO_SECONDS,0.5*N_d),
                                                    learning_rate = 1e-3,
                                                    maxiter=100)
print('incubation time:',  p.t_inc(Ta))
print("Solved baseline temperature (C):", Ta - 273.15)
print("Incubation time (s):", tau_inc)
print("Recrystallisation time (s):", tau_rx)
print('N_d_inc:', N_d_inc)
print('N_d_rx:', N_d_rx)
print("Error:", error)
#print("Recrystallisation time (s):", tau_rx)
#print("LBFGS info:", info)



isothermal temperature: 1086.27724609375
incubation time: 18921580.0
Solved baseline temperature (C): 723.4843383789063
Incubation time (s): 18921600.0
Recrystallisation time (s): 12614400.0
N_d_inc: -5.484165578828205e-35
N_d_rx: 5.484165578828205e-35
Error: 0.1254045069217682


