### Isothermal JMAK Function 
$$
X(t,T) = 1 - \exp\left\{-b(T)^n\left(t - t_{inc}(T)\right)^n\right\} \\
t_{inc}(T) = A_1 \exp\left(B_1/T\right) \qquad b(T) = A_2 \exp\left(B_2/T\right) \\
a_1 = \ln{A_1}, \quad a_2 = \ln{A_2}
$$
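
As a quick illustration of the isothermal form, the sketch below evaluates $X(t,T)$ in plain Python. The Arrhenius constants are hypothetical values chosen only so the numbers are readable, and setting $X = 0$ for $t \leq t_{inc}$ is an assumption (the formula itself is only meaningful after incubation).

```python
import math

# Hypothetical Arrhenius parameters, for illustration only (not fitted values).
A1_LN, B1 = -20.0, 40000.0   # a_1 = ln A_1, B_1 (K)
A2_LN, B2 = 15.0, -45000.0   # a_2 = ln A_2, B_2 (K)
N_AVRAMI = 1.5               # Avrami exponent n


def t_inc(T):
    """Incubation time t_inc(T) = A_1 exp(B_1 / T)."""
    return math.exp(A1_LN + B1 / T)


def b(T):
    """Rate coefficient b(T) = A_2 exp(B_2 / T)."""
    return math.exp(A2_LN + B2 / T)


def X_iso(t, T, n=N_AVRAMI):
    """Isothermal JMAK fraction; assumed to be zero until incubation ends."""
    dt = t - t_inc(T)
    if dt <= 0.0:
        return 0.0
    return 1.0 - math.exp(-((b(T) * dt) ** n))


T = 1400.0
for t in (0.5 * t_inc(T), t_inc(T) + 1.0 / b(T), t_inc(T) + 3.0 / b(T)):
    print(f"t = {t:.3e} s, X = {X_iso(t, T):.4f}")
```

At $t = t_{inc} + 1/b$ the exponent is exactly $-1$ regardless of $n$, which makes a convenient sanity check.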

### Nonisothermal JMAK Function
$$
Y(t)  = \begin{cases}
0 \qquad & t \leq t_0 \\ 
Y^{\dagger}(t - t_0) \qquad & t > t_0 \\ 
\end{cases}
$$

$$
\begin{matrix}
t_0 = \rho^{-1}(1) \qquad \rho(t) = \int_{0}^{t} \frac{d \tau}{t_{inc}(T(\tau))}\\ 
Y^{\dagger}(t) = 1 - \exp{\left\lbrace - \left( \int_{0}^{t} b(T(\tau)) d\tau \right)^n  \right\rbrace } 
\end{matrix}
$$
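
A minimal sketch of how $Y(t)$ could be evaluated on a sampled temperature history is shown below. It uses hypothetical Arrhenius constants, the trapezoid rule, and resolves $t_0$ only to the nearest grid point, so it is an approximation rather than the notebook's implementation. For a constant temperature it should reproduce the isothermal result.

```python
import math

# Hypothetical Arrhenius parameters, for illustration only.
A1_LN, B1 = -20.0, 40000.0   # a_1 = ln A_1, B_1 (K)
A2_LN, B2 = 15.0, -45000.0   # a_2 = ln A_2, B_2 (K)
N_AVRAMI = 1.5


def t_inc(T):
    return math.exp(A1_LN + B1 / T)


def b(T):
    return math.exp(A2_LN + B2 / T)


def Y_nonisothermal(times, temps, n=N_AVRAMI):
    """Return Y at every sample time for a sampled temperature history."""
    rho = 0.0      # running incubation integral of 1 / t_inc(T)
    growth = 0.0   # running integral of b(T) once rho has reached 1
    Y = [0.0]
    for k in range(1, len(times)):
        dt = times[k] - times[k - 1]
        if rho < 1.0:
            # still incubating: accumulate rho with the trapezoid rule
            rho += 0.5 * dt * (1.0 / t_inc(temps[k - 1]) + 1.0 / t_inc(temps[k]))
            Y.append(0.0)
        else:
            # incubation finished (t_0 resolved to the grid): accumulate growth
            growth += 0.5 * dt * (b(temps[k - 1]) + b(temps[k]))
            Y.append(1.0 - math.exp(-(growth ** n)))
    return Y


# Constant-temperature check against the isothermal formula.
T_const = 1400.0
t_end = t_inc(T_const) + 2.0 / b(T_const)
steps = 20000
times = [t_end * k / steps for k in range(steps + 1)]
temps = [T_const] * (steps + 1)
Y = Y_nonisothermal(times, temps)
X_exact = 1.0 - math.exp(-(2.0 ** N_AVRAMI))
print(f"Y(t_end) = {Y[-1]:.5f}, isothermal X(t_end) = {X_exact:.5f}")
```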

### Inverted Isothermal JMAK function 
$$
t_{iso}(X = X^*,T) = t_{inc}(T) + \frac{1}{b(T)}\left(-\ln{(1 - X^*)}\right)^{1/n} 
$$ 

Clearly, for every $X^*$ and specified operating time $\tau_{op}$ we can find a $T_{iso}$ such that $t_{iso}(X = X^*,T = T_{iso}) = \tau_{op}$.
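
One way to find such a $T_{iso}$ is bisection on $g(T) = t_{iso}(X^*, T) - \tau_{op}$. The sketch below assumes hypothetical Arrhenius constants for which $t_{iso}$ decreases monotonically with $T$, and a bracket `[T_low, T_high]` that contains the root; both are assumptions of this example.

```python
import math

# Hypothetical Arrhenius parameters, for illustration only.
A1_LN, B1 = -20.0, 40000.0   # a_1 = ln A_1, B_1 (K)
A2_LN, B2 = 15.0, -45000.0   # a_2 = ln A_2, B_2 (K)
N_AVRAMI = 1.5


def t_inc(T):
    return math.exp(A1_LN + B1 / T)


def b(T):
    return math.exp(A2_LN + B2 / T)


def t_iso(X_star, T, n=N_AVRAMI):
    """Inverted isothermal JMAK: time needed to reach X_star at temperature T."""
    return t_inc(T) + (-math.log1p(-X_star)) ** (1.0 / n) / b(T)


def solve_T_iso(X_star, tau_op, T_low=800.0, T_high=3000.0, tol=1e-6):
    """Bisection on g(T) = t_iso(X_star, T) - tau_op over [T_low, T_high]."""
    g_low = t_iso(X_star, T_low) - tau_op
    g_high = t_iso(X_star, T_high) - tau_op
    if g_low * g_high > 0.0:
        raise ValueError("bracket does not contain a sign change")
    while T_high - T_low > tol:
        T_mid = 0.5 * (T_low + T_high)
        g_mid = t_iso(X_star, T_mid) - tau_op
        if g_low * g_mid <= 0.0:
            T_high = T_mid            # root lies in the lower half
        else:
            T_low, g_low = T_mid, g_mid  # root lies in the upper half
    return 0.5 * (T_low + T_high)


tau_op = 365.0 * 24.0 * 3600.0  # one year in seconds
T_iso = solve_T_iso(0.1, tau_op)
print(f"T_iso = {T_iso:.2f} K, t_iso(T_iso) / tau_op = {t_iso(0.1, T_iso) / tau_op:.8f}")
```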

We want to examine how transient events can impact the computed $T_{iso}$. 

## Recrystallization

### Inverted Non-isothermal JMAK Function
$$
t_{a}(X = X^*,T) = \tau_{a,inc} + \tau_{a,rx} \\ 
$$

where $\tau_{a,inc}$ solves: 

$$
\int_{0}^{\tau_{a,inc}} \frac{dt}{t_{inc}(T(t))} = 1
$$

and $\tau_{a,rx}$ solves 
$$
\left(-\ln{(1 - X^*)}\right)^{1/n} = \int_{0}^{\tau_{a,rx}} b(T(t)) dt
$$
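
Both conditions have the same shape: find the smallest $\tau$ at which a running integral of a positive rate reaches a target. A minimal sketch of that idea follows, with hypothetical Arrhenius constants, a hypothetical sinusoidal temperature history, and the assumption that the growth integral starts at the end of incubation on the same clock.

```python
import math

# Hypothetical Arrhenius parameters, for illustration only.
A1_LN, B1 = -20.0, 40000.0
A2_LN, B2 = 15.0, -45000.0
N_AVRAMI = 1.5


def t_inc(T):
    return math.exp(A1_LN + B1 / T)


def b(T):
    return math.exp(A2_LN + B2 / T)


def time_to_reach(rate, target, dt, t_start=0.0, max_steps=10_000_000):
    """Smallest tau with int_{t_start}^{t_start + tau} rate(t) dt = target."""
    total, t = 0.0, t_start
    r_prev = rate(t)
    for _ in range(max_steps):
        r_next = rate(t + dt)
        step = 0.5 * dt * (r_prev + r_next)   # trapezoid rule on one step
        if total + step >= target:
            # linear interpolation inside the final step
            return (t - t_start) + dt * (target - total) / step
        total += step
        t += dt
        r_prev = r_next
    raise RuntimeError("target not reached within max_steps")


def solve_times(T_of_t, X_star, dt_inc=10.0, dt_rx=1.0e4, n=N_AVRAMI):
    """Return (tau_inc, tau_rx) for a temperature history T_of_t(t)."""
    tau_inc = time_to_reach(lambda t: 1.0 / t_inc(T_of_t(t)), 1.0, dt_inc)
    ln_term = (-math.log1p(-X_star)) ** (1.0 / n)
    tau_rx = time_to_reach(lambda t: b(T_of_t(t)), ln_term, dt_rx, t_start=tau_inc)
    return tau_inc, tau_rx


T0 = 1400.0
tau_inc_c, tau_rx_c = solve_times(lambda t: T0, 0.1)
# Hypothetical disturbance: +/- 30 K oscillation with a 600 s period.
tau_inc_w, _ = solve_times(lambda t: T0 + 30.0 * math.sin(2.0 * math.pi * t / 600.0), 0.1, dt_rx=5.0)
print(f"constant T:    tau_inc = {tau_inc_c:.1f} s, tau_rx = {tau_rx_c:.3e} s")
print(f"oscillating T: tau_inc = {tau_inc_w:.1f} s")
```

For a constant temperature the two times reduce to $t_{inc}(T)$ and $\left(-\ln(1-X^*)\right)^{1/n}/b(T)$, which is what the first call checks.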


Let $\mathcal{T}_{inc} = \{t: 0 \leq t \leq \tau_{a,inc}\}$ and similarly $\mathcal{T}_{rx} = \{t: 0 \leq t \leq \tau_{a,rx}\}$. Assume that these sets contain $N_{d,inc}$ and $N_{d,rx}$ disturbance windows, respectively, each of the form

$D_i = \{t: s_i \leq t \leq s_i + \tau_r\}, i = 1,...,N_{d,\cdot}$, and further that $s_{i} \geq s_{i-1} + \tau_r \ \forall i = 2,..., N_{d,\cdot}$ and $0 < s_{i} < \tau_{op} - \tau_r \ \forall i$. Define $\mathcal{T}^d_{inc} = \cup_{i = 1}^{N_{d,inc}} D_i$, with a similar definition of $\mathcal{T}_{rx}^d$ for the $D_i$ during recrystallization. The complements are $\mathcal{T}_{inc}^s = \mathcal{T}_{inc} \backslash \mathcal{T}_{inc}^d$ and $\mathcal{T}_{rx}^s = \mathcal{T}_{rx} \backslash \mathcal{T}_{rx}^d$.

On each $D_i$ the temperature profile may be written as $T_r(t;T_a) = T_{a} + \Delta T_r(t)$, and we assume that the profile is $C^1(\mathcal{T}_{inc})$ and $C^1(\mathcal{T}_{rx})$. On $\mathcal{T}_{inc}^s$ and $\mathcal{T}_{rx}^s$ the temperature is constant at $T_{a}$. The _goal_ is to solve for $T_a$.

Notice that the ordering of the sets and their spacing do not impact the overall integral evaluations in the non-isothermal JMAK equations. In fact, we may write: 

$$
\int_{0}^{\tau_{a,rx}} b(T(t)) dt = \int_{0}^{\tau_{a,rx} - N_{d,rx} \tau_r} b(T_a) dt + N_{d,rx} \int_{0}^{\tau_r} b(T_r(t;T_a)) dt = \left(\tau_{a,rx} - N_{d,rx} \tau_r \right) b(T_a) + N_{d,rx} \int_{0}^{\tau_r} b(T_r(t;T_a)) dt \\ 
$$

and 

$$
\int_{0}^{\tau_{a,inc}} \frac{dt}{t_{inc}(T(t))} = \frac{\tau_{a,inc} - N_{d,inc}\tau_r}{t_{inc}(T_a)} + N_{d,inc} \int_{0}^{\tau_r} \frac{1}{t_{inc}(T_r(t; T_a))} dt
$$

with the important implication being that we only have to evaluate the integrals for a single disturbance temperature profile, not over the whole set of time. 
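
The decomposition can be checked numerically. The sketch below compares a direct trapezoid integral of $b(T(t))$ over a full history containing $N$ windows with the closed form $(\tau - N\tau_r)\,b(T_a) + N\int_0^{\tau_r} b(T_a + \Delta T_r(s))\,ds$. The triangular $\Delta T_r$, the window start times, and the Arrhenius constants are all hypothetical choices for this example.

```python
import math

# Hypothetical rate parameters, for illustration only.
A2_LN, B2 = 15.0, -45000.0


def b(T):
    return math.exp(A2_LN + B2 / T)


def delta_T(s, tau_r, peak=200.0):
    """Hypothetical triangular excursion that is zero at both window edges."""
    if s < 0.0 or s > tau_r:
        return 0.0
    return peak * (1.0 - abs(2.0 * s / tau_r - 1.0))


def trapz(f, t0, t1, n):
    """Composite trapezoid rule with n equal steps."""
    h = (t1 - t0) / n
    total = 0.5 * (f(t0) + f(t1))
    for k in range(1, n):
        total += f(t0 + k * h)
    return total * h


T_a, tau, tau_r = 1400.0, 1000.0, 10.0


def integrate_full(starts):
    """Direct integral over [0, tau] with one window at each start time."""
    def T_full(t):
        return T_a + sum(delta_T(t - s, tau_r) for s in starts)
    return trapz(lambda t: b(T_full(t)), 0.0, tau, 10_000)


starts = [100.0, 400.0, 700.0]
direct = integrate_full(starts)
shifted = integrate_full([50.0, 60.0, 900.0])   # same windows, different placement

# Closed form: steady part plus N copies of a single-window integral.
window = trapz(lambda s: b(T_a + delta_T(s, tau_r)), 0.0, tau_r, 100)
decomposed = (tau - len(starts) * tau_r) * b(T_a) + len(starts) * window

print(f"direct     = {direct:.6e}")
print(f"shifted    = {shifted:.6e}")
print(f"decomposed = {decomposed:.6e}")
```

Because only the number of windows matters, the single-window integral can be computed once per trial value of $T_a$ and reused.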

$$
\begin{matrix}
\tau_{op} = \tau_{a,inc} + \tau_{a,rx}  \qquad \text{equivalency to the isothermal case} \\ 
N_d = N_{d,inc} + N_{d,rx} \qquad \text{total number of cycles} \\ 
\frac{N_{d,inc}}{N_{d,rx}} = \alpha \lceil \frac{\tau_{a,inc}}{\tau_{a,rx}} \rceil \qquad \text{approximately evenly distributed cycles.}
\end{matrix}
$$
The intent is to specify the temperature profile as mostly constant with some number $N_{d}$ of transient/disturbance periods and see how $T_a$ is affected relative to $T_{iso}$ as $N_{d}$ increases. 
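
For intuition about the last two constraints, the split of $N_d$ between the incubation and recrystallization periods can be written in closed form if the counts are treated as continuous and the ceiling is dropped. That relaxation is an assumption of this sketch, and `split_cycles` is a hypothetical helper name.

```python
def split_cycles(N_d, tau_inc, tau_rx, alpha=1.0):
    """Split N_d so that N_inc / N_rx = alpha * tau_inc / tau_rx and N_inc + N_rx = N_d.

    Counts are treated as real numbers (continuous relaxation).
    """
    r = alpha * tau_inc / tau_rx
    N_inc = N_d * r / (1.0 + r)
    return N_inc, N_d - N_inc


# With alpha = 1 the cycles are distributed in proportion to the durations.
N_inc, N_rx = split_cycles(N_d=100, tau_inc=1.0, tau_rx=3.0, alpha=1.0)
print(N_inc, N_rx)

# alpha > 1 shifts cycles toward the incubation period.
N_inc_2, N_rx_2 = split_cycles(N_d=100, tau_inc=1.0, tau_rx=3.0, alpha=3.0)
print(N_inc_2, N_rx_2)
```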


The parameters $a_1,B_1,a_2,B_2,n$ are estimated from data and specified.

The parameters $X^*,\tau_{op}$ are known and specified.

I approximate and specify $\Delta T_r(t)$ and $\tau_r$.

I want to vary $\tau_{op}$ and $N_d$ and study the effect they have. $\alpha$ is something of a nuisance parameter because the effect of the disturbance temperatures may be somewhat different for the incubation and recrystallization time periods. 

In [142]:
import jax
import jax.numpy as jnp
from jax import jit, grad,vmap
from jaxopt import Broyden, Bisection
from dataclasses import dataclass
from typing import Tuple
from typing import Callable
import warnings 

__all__ = [
    "JMAKParams",
    "invert_isothermal_T",
    "solve_Ta_and_tau_inc",
]


@dataclass
class JMAKParams:
    """Physical and operational parameters for the non‑isothermal JMAK solver."""

    # Kinetic parameters (Arrhenius form)
    n: float   # Avrami exponent
    a1: float  # ln A1 for growth
    B1: float  # activation term for growth (K)
    a2: float  # ln A2 for  incubation
    B2: float  # activation term for incubation (K)

    # Target & schedule
    X_star: float  # required transformed fraction (0 < X_star < 1)
    tau_op: float  # total operation time (s)

    # Disturbance specification
    tau_r: float         # duration of a single disturbance window (s)
    N_d: int
    alpha: float
    deltaT_r: jnp.ndarray  # temperature *offset* profile for one window (shape (Nt,))
    t: jnp.ndarray         # time profile for one window (shape (Nt,))

    # Numerical safety bounds
    Tmin: float = 600.0   # K
    Tmax: float = 4000.0  # K

    # ---------- kinetic helpers ----------
    def t_inc(self, T: jnp.ndarray) -> jnp.ndarray:
        """Incubation time (s)."""
        return jnp.exp(self.a2 + self.B2 / T)

    def b(self, T: jnp.ndarray) -> jnp.ndarray:
        """Growth‑rate coefficient (1/s); enters the JMAK exponent as (b * t)^n."""
        return jnp.exp(self.a1 + self.B1 / T)

    @property
    def ln_term(self) -> jnp.ndarray:
        """Natural log term for the recrystallization rate equation."""
        return (-jnp.log1p(-self.X_star)) ** (1.0 / self.n)
# -----------------------------------------------------------
# One‑window integrals (re‑evaluated each time Ta changes)
# -----------------------------------------------------------

def _single_window_integrals(p: JMAKParams):
    """Return jit‑compiled integrators for one disturbance window."""

    if p.tau_r == 0.0 or p.N_d == 0 or p.deltaT_r is None:
        # Degenerate case – no disturbance → zero contribution
        return lambda Ta: 0.0, lambda Ta: 0.0

    @jit
    def I_inc_r(Ta):
        T_prof = Ta + p.deltaT_r
        return jnp.trapezoid(1.0 / p.t_inc(T_prof), x=p.t)

    @jit
    def I_rx_r(Ta):
        T_prof = Ta + p.deltaT_r
        return jnp.trapezoid(p.b(T_prof), x=p.t)

    return I_inc_r, I_rx_r


# -----------------------------------------------------------
# Residual vector for Broyden (F1, F2, F3) = 0
# -----------------------------------------------------------

def _make_residuals(p: JMAKParams,
                    eps_t: float = 1e-16):
    """Return bounded residuals using a sigmoidal re‑parameterisation.

    We optimise unconstrained variables z = [theta, rho, kappa] in R^3 and map
    them to the physical domain via the logistic function:

        T_a      =  T_min  + logistic(theta) * (T_max - T_min)
        tau_inc  =  logistic(rho)   * tau_op
        N_d_inc  =  logistic(kappa) * N_d

    This guarantees that every function evaluation stays inside the admissible
    box – no clipping, no penalties, no NaNs.
    """

    I_inc_r, I_rx_r = _single_window_integrals(p)
    ln_term = p.ln_term

    logistic = lambda x: 1.0 / (1.0 + jnp.exp(-x))

    @jit
    def residuals(z):
        theta, rho, kappa = z

        # ---------- inverse transform ----------
        Ta       = p.Tmin + logistic(theta) * (p.Tmax - p.Tmin)
        tau_inc  = logistic(rho)   * p.tau_op
        N_d_inc  = logistic(kappa) * p.N_d

        tau_rx   = p.tau_op - tau_inc
        N_d_rx   = p.N_d   - N_d_inc

        # avoid division by zero inside F3 (does not affect root)
        tau_rx = jnp.where(tau_rx < eps_t,  eps_t, tau_rx)
        N_d_rx = jnp.where(N_d_rx <  eps_t,  eps_t, N_d_rx)

        # ---------- residuals ----------
        F1 = (
            (tau_inc - N_d_inc * p.tau_r) / p.t_inc(Ta)
            + N_d_inc * I_inc_r(Ta)
            - 1.0
        )

        # rescale F2 to O(1): use relative error to ln_term
        num_rx = (tau_rx - N_d_rx * p.tau_r) * p.b(Ta) + N_d_rx * I_rx_r(Ta)
        F2 = num_rx / ln_term - 1.0 

        # scale F3 by (N_d + 1) * tau_op to keep magnitude ~1
        F3 = (N_d_inc * tau_rx - p.alpha * tau_inc * N_d_rx) / ( (p.N_d + 1.0) * p.tau_op )
        return jnp.array([F1, F2, F3])

    return residuals


# -----------------------------------------------------------
# Solver using jaxopt.Broyden (quasi‑Newton root finder)
# -----------------------------------------------------------

def solve_Ta_and_tau_inc(
    params: JMAKParams,
    x0: Tuple[float, float, float] = (1200.0, 0.5, 0.0),  # physical guess
    tol: float = 1e-3,
    maxiter: int = int(1e4),
    eps_t: float = 1e-16
):
    """Root‑solve using Broyden in *unconstrained* space.

    Parameters
    ----------
    params   : JMAKParams
    x0       : initial guess in physical variables (T_a [K], tau_inc [s], N_d_inc)
    tol      : tolerance on ‖F‖₂
    maxiter  : maximum Broyden iterations
    eps_t    : small positive floor that guards divisions in the residuals
    """

    T0, tau_inc0, N_d_inc0 = x0

    # logit function (inverse of logistic); the argument is clipped away from
    # 0 and 1 so that boundary guesses (e.g. N_d_inc = 0) do not map to ±inf
    def logit(y, eps=1e-6):
        y = jnp.clip(y, eps, 1.0 - eps)
        return jnp.log(y / (1.0 - y))

    # map initial guess to unconstrained variables
    theta0 = logit((T0 - params.Tmin) / (params.Tmax - params.Tmin + eps_t))
    rho0   = logit(tau_inc0 / (params.tau_op + eps_t))
    kappa0 = logit(N_d_inc0 / (params.N_d + eps_t)) if params.N_d > 0 else 0.0

    residuals = _make_residuals(params, eps_t)
    solver = Broyden(fun=residuals, tol=tol, maxiter=maxiter,
                     max_stepsize = 1.,implicit_diff= False)
    sol = solver.run(jnp.array([theta0, rho0, kappa0], dtype=float))

    theta, rho, kappa = sol.params
    logistic = lambda x: 1.0 / (1.0 + jnp.exp(-x))

    Ta_opt      = params.Tmin + logistic(theta) * (params.Tmax - params.Tmin)
    tau_inc_opt = logistic(rho)   * params.tau_op
    N_d_inc_opt = logistic(kappa) * params.N_d

    tau_rx_opt  = params.tau_op - tau_inc_opt
    N_d_rx_opt  = params.N_d   - N_d_inc_opt

    info = {
        "iterations": int(sol.state.iter_num),
        "residual_norm": float(sol.state.error),
    }

    return (
        float(Ta_opt),
        float(tau_inc_opt),
        float(tau_rx_opt),
        float(N_d_inc_opt),
        float(N_d_rx_opt),
        info,
    )





# -----------------------------------------------------------
# Optional: invert isothermal equation for reference temperature
# -----------------------------------------------------------

def invert_isothermal_T(
    params: JMAKParams,
    T_low: float = 200.0,
    T_high: float = 3000.0,
    tol: float = 1e-6,
):
    """Return the isothermal temperature that gives ``X_star`` in ``tau_op`` seconds."""

    # (-ln(1 - X*))^(1/n), shared with the non-isothermal residuals
    ln_term = params.ln_term

    def g(T):
        # residual of t_iso(X*, T) = tau_op
        return params.t_inc(T) + ln_term / params.b(T) - params.tau_op

    root_solver = Bisection(g, T_low, T_high, tol=tol)
    return float(root_solver.run().params)


### Isothermal Verification

In [145]:
import sys
wlpath = '../recrystallization'
sys.path.append(
    wlpath
)
from external import read_jmak_model_inference,_RX_INF_PATH as rxp

jmak_model = read_jmak_model_inference(
    rxp.joinpath(f'JMAK_Lopez et al. (2015) - MR_trunc_normal_params.csv')
)
a1,B1,a2,B2,n = jmak_model.a1,jmak_model.B1,jmak_model.a2,jmak_model.B2,jmak_model.n
YEAR_TO_SECONDS = 24. * 365.*3600.
N_d = 0
tau_r = 10.
p = JMAKParams(
    n=n,
    a1=a1,
    B1=B1,
    a2=a2,
    B2=B2,
    X_star=0.1,
    tau_op=YEAR_TO_SECONDS,
    tau_r=tau_r,
    N_d = N_d,
    alpha = 1.0,
    deltaT_r=None,
    t = None,
)

T_iso = invert_isothermal_T(p)
Ta = T_iso
tau_inc = p.t_inc(Ta)
tau_rx = p.ln_term/p.b(Ta)
print('incubation time:',  p.t_inc(Ta))
print('recrystallization time:', tau_rx)
N_d_inc = 0.
I_inc_r = 0.
N_d_rx = 0.
I_rx_r = 0.

t = jnp.linspace(0, tau_r, 1000)

print("isothermal temperature:", T_iso - 273.15)
p.deltaT_r = jnp.ones(t.shape[0])*0.
p.t = t
# the last return value is a dict with the iteration count and residual norm
Ta, tau_inc, tau_rx, N_d_inc, N_d_rx, info = solve_Ta_and_tau_inc(
    p, x0=(T_iso, tau_inc * 2.3, 0.5 * N_d)
)

print("Solved baseline temperature (C):", Ta - 273.15)
print("Incubation time (s):", tau_inc)
print("Recrystallisation time (s):", tau_rx)
print('N_d_inc:', N_d_inc)
print('N_d_rx:', N_d_rx)
print("Solver info:", info)



incubation time: 21827.934
recrystallization time: 31514226.0
isothermal temperature: 1086.27724609375
INFO: jaxopt.BacktrackingLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.31620657444000244 Stepsize:1.0  Objective Value:1.9999394416809082  Decrease Error:0.31620657444000244 
INFO: jaxopt.BacktrackingLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.31448233127593994 Stepsize:0.800000011920929  Objective Value:1.9994676113128662  Decrease Error:0.31448233127593994 
INFO: jaxopt.BacktrackingLineSearch: Iter: 3 Minimum Decrease & Curvature Errors (stop. crit.): 0.3107224702835083 Stepsize:0.64000004529953  Objective Value:1.9967097043991089  Decrease Error:0.3107224702835083 
INFO: jaxopt.BacktrackingLineSearch: Iter: 4 Minimum Decrease & Curvature Errors (stop. crit.): 0.29873478412628174 Stepsize:0.5120000243186951  Objective Value:1.9855235815048218  Decrease Error:0.29873478412628174 
INFO: jaxopt.BacktrackingLineSearch: Iter: 5 Mini