In [None]:
import os
import numpy as np
import jax
import jax.numpy as jnp
import arviz as az

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

Basic EIS Routines

In [None]:
def Warburg_element(freq_hz, log_tau0, alpha):
    """
    Impedance of one (generalised) Warburg element.

      Z_W(f) = 1 / (j * omega * tau0)^alpha,   omega = 2*pi*f,  tau0 = exp(log_tau0)

    Parameters
    ----------
    freq_hz  : frequency value(s) f.
    log_tau0 : natural log of the characteristic time constant tau0.
    alpha    : exponent of the element. In this notebook's model the value
               passed in is the `alpha_w` sample site (Exponential prior).

    Returns
    -------
    Complex impedance with the broadcast shape of the inputs.
    """
    time_const = jnp.exp(log_tau0)
    ang_freq = 2.0 * jnp.pi * freq_hz
    j_omega_tau = 1j * ang_freq * time_const
    return 1.0 / (j_omega_tau ** alpha)

def HN_element(freq_hz, w, phi, alpha, log_tau0):
    """
    Impedance of one Havriliak-Negami (HN) arc.

      Z_k(f) = w / [1 + (j * 2*pi*f * tau0)^phi]^alpha,   tau0 = exp(log_tau0)

    Parameters
    ----------
    freq_hz  : frequency value(s) f.
    w        : amplitude (numerator) of the arc.
    phi      : inner exponent applied to (j*omega*tau0).
    alpha    : outer exponent applied to the bracket.
    log_tau0 : natural log of the characteristic time constant tau0.

    Returns
    -------
    Complex impedance with the broadcast shape of the inputs.
    """
    time_const = jnp.exp(log_tau0)
    ang_freq = 2.0 * jnp.pi * freq_hz
    bracket = 1.0 + (1j * ang_freq * time_const)**phi
    return w / (bracket ** alpha)

def stick_breaking(beta):
    """
    Standard stick-breaking: turn K-1 Beta draws into K mixture weights.

      pi_k = beta_k * prod_{i<k} (1 - beta_i)   for k = 1..K-1
      pi_K =          prod_{i<K} (1 - beta_i)   (the remaining stick)

    Parameters
    ----------
    beta : array-like whose leading axis indexes the K-1 breaks
           (any trailing axes are carried through unchanged).

    Returns
    -------
    jnp array with K entries along the leading axis; the weights sum to 1
    along that axis. An empty `beta` yields the single weight [1].

    Implemented with a cumulative product instead of a Python loop so the
    computation stays a handful of array ops when traced by JAX.
    """
    beta = jnp.asarray(beta)
    # One "unit" slice along the leading axis: the stick length before any
    # break, and also the factor that passes the leftover stick through.
    unit = jnp.ones((1,) + beta.shape[1:], dtype=beta.dtype)
    # remaining[k] = prod_{i<k} (1 - beta_i); remaining[0] = 1.
    remaining = jnp.concatenate([unit, jnp.cumprod(1.0 - beta, axis=0)], axis=0)
    # Fraction taken from the remaining stick at each step; the last piece
    # takes everything that is left (fraction 1).
    fractions = jnp.concatenate([beta, unit], axis=0)
    return fractions * remaining

def fct_Z_Warburg_exact(freq_vec, alpha_w, tau0_w):
    """
    NumPy reference impedance of one Warburg element (used to build the
    synthetic data):

      Z_W = 1 / (j * omega * tau0_w)^alpha_w,   omega = 2*pi*freq_vec

    Note that tau0_w is the time constant itself, not its logarithm.
    """
    ang_freq = 2.0 * np.pi * freq_vec
    j_omega_tau = 1j * ang_freq * tau0_w
    return 1.0 / (j_omega_tau ** alpha_w)

def fct_Z_HN_exact(freq_vec, R_inf, R_ct, phi, alpha, tau0):
    """
    NumPy reference impedance of a series resistance plus one HN arc (used
    to build the synthetic data):

      Z = R_inf + R_ct / [1 + (j * omega * tau0)^phi]^alpha,   omega = 2*pi*freq_vec
    """
    ang_freq = 2.0 * np.pi * freq_vec
    bracket = 1.0 + (1j * ang_freq * tau0)**phi
    return R_inf + R_ct / (bracket ** alpha)

Warburg & HN DRT Routines

In [None]:
def fct_gamma_Warburg_exact(tau_vec, log_tau0, alpha):
    """
    Exact DRT of a Warburg element with arbitrary exponent alpha
    (alpha = 0.5 is just one special case):

      gamma_W(tau) = (1/pi) * sin(alpha*pi) * (tau/tau0)^alpha,   tau0 = exp(log_tau0)

    Unlike the impedance helper, this function takes the *log* of tau0.
    """
    time_const = np.exp(log_tau0)
    scale = (1.0 / np.pi) * np.sin(alpha * np.pi)
    return scale * (tau_vec / time_const)**alpha

def theta_HN(tau, tau0, phi):
    """
    Phase angle that appears in the HN distribution of relaxation times:

      theta = atan2( sin(pi*phi), (tau/tau0)^phi + cos(pi*phi) )

    Parameters
    ----------
    tau  : relaxation time(s) at which the DRT is evaluated.
    tau0 : characteristic time constant of the arc.
    phi  : inner HN exponent.
    """
    return np.arctan2(
        np.sin(np.pi * phi),
        ((tau / tau0)**phi + np.cos(np.pi * phi))
    )

def fct_gamma_HN_exact(tau_vec, R_ct, phi, alpha, tau0):
    """
    Distribution of Relaxation Times for a single HN arc
    Z = R_ct / [1 + (j*omega*tau0)^phi]^alpha:

      gamma(tau) = R_ct/pi * (tau/tau0)^(alpha*phi) * sin(alpha*theta_HN)
                   / [1 + 2 cos(pi phi)(tau/tau0)^phi + (tau/tau0)^(2 phi)]^(alpha/2)

    Parameters
    ----------
    tau_vec : relaxation time(s) at which to evaluate the DRT.
    R_ct    : amplitude of the arc.
    phi     : inner HN exponent.
    alpha   : outer HN exponent.
    tau0    : characteristic time constant (not its logarithm).

    Returns
    -------
    gamma evaluated at tau_vec (same shape as tau_vec).
    """
    ratio_phi = (tau_vec / tau0)**phi
    prefactor = R_ct / np.pi
    theta_val = theta_HN(tau_vec, tau0, phi)

    # ratio_phi already carries the exponent phi, so (tau/tau0)^(alpha*phi)
    # is ratio_phi**alpha. Raising it to alpha*phi would give
    # (tau/tau0)^(alpha*phi^2), which is not the HN DRT.
    numerator = ratio_phi**alpha * np.sin(alpha * theta_val)
    denominator = (1.0 + 2.0*np.cos(np.pi*phi)*ratio_phi + ratio_phi**2)**(alpha/2.0)
    return prefactor * (numerator / denominator)

In [None]:
def dp_warburg_hn_model(freq_data, Z_re_data, Z_im_data, K_W=5, K_HN=5):
    """
    NumPyro model: series resistance + truncated-DP mixture of Warburg
    elements + truncated-DP mixture of HN arcs.

    Predicted impedance:
      Z(f) = R_inf
             + sum_{k=1}^{K_W}  pi_w[k]  * Warburg_element(f, log_tau0_w[k], alpha_w)
             + sum_{m=1}^{K_HN} pi_hn[m] * HN_element(f, w_hn[m], phi_hn[m],
                                                      alpha_hn_arcs[m], log_tau0_hn[m])

    Weights:
      - Each family gets its own stick-breaking weights (K-1 Beta draws -> K weights).
      - `p_W_fraction` ~ Beta(1, 1) splits the total mass: the Warburg weights
        are scaled to sum to p_W, the HN weights to sum to (1 - p_W).
      - Warburg components have no separate amplitude; HN components carry
        their own amplitude w_hn[m].

    NOTE(review): the single site `alpha_w` ~ Exponential(1.0) is used both as
    the concentration in Beta(1, alpha_w) for the Warburg sticks and as the
    exponent of every Warburg element. The HN family keeps these separate
    (`alpha_hn` vs `alpha_hn_arcs`). Confirm the sharing is intended.

    Parameters
    ----------
    freq_data : 1-D array of frequencies.
    Z_re_data : observed real part of the impedance, one value per frequency.
    Z_im_data : observed imaginary part of the impedance, one value per frequency.
    K_W       : truncation level (number of components) of the Warburg family.
    K_HN      : truncation level (number of components) of the HN family.
    """
    # The order of the sample statements below is kept exactly as in the
    # original model so that seeding and initialisation are unchanged.

    # Series resistance
    R_inf = numpyro.sample("R_inf", dist.HalfNormal(50.0))

    # Share of the mixture mass assigned to the Warburg family
    frac_warburg = numpyro.sample("p_W_fraction", dist.Beta(1.0, 1.0))

    # ------------------ Warburg family ------------------
    alpha_w = numpyro.sample("alpha_w", dist.Exponential(1.0))
    ones_w = jnp.ones(K_W - 1)
    sticks_w = numpyro.sample("beta_w", dist.Beta(ones_w, alpha_w * ones_w))
    weights_w = stick_breaking(sticks_w) * frac_warburg

    log_tau0_w = numpyro.sample("log_tau0_w", dist.Normal(-1.0, 2.0).expand([K_W]))

    # ------------------ HN family -----------------------
    conc_hn = numpyro.sample("alpha_hn", dist.Exponential(1.0))
    ones_hn = jnp.ones(K_HN - 1)
    sticks_hn = numpyro.sample("beta_hn", dist.Beta(ones_hn, conc_hn * ones_hn))
    weights_hn = stick_breaking(sticks_hn) * (1.0 - frac_warburg)

    amp_hn = numpyro.sample("w_hn", dist.HalfNormal(50.0).expand([K_HN]))
    phi_hn = numpyro.sample("phi_hn", dist.Uniform(0.0, 1.0).expand([K_HN]))
    exponent_hn = numpyro.sample("alpha_hn_arcs", dist.Uniform(0.0, 1.0).expand([K_HN]))
    log_tau0_hn = numpyro.sample("log_tau0_hn", dist.Normal(0.0, 2.0).expand([K_HN]))

    # ------------------ Observation noise ---------------
    sigma_re = numpyro.sample("sigma_re", dist.HalfNormal(0.3))
    sigma_im = numpyro.sample("sigma_im", dist.HalfNormal(0.3))

    def impedance_at(f):
        # Start from the (complex-cast) series resistance, then add the
        # weighted Warburg terms followed by the weighted HN terms.
        total = R_inf + 0.0j

        for k in range(K_W):
            total = total + weights_w[k] * Warburg_element(f, log_tau0_w[k], alpha_w)

        for m in range(K_HN):
            arc = HN_element(f, amp_hn[m], phi_hn[m], exponent_hn[m], log_tau0_hn[m])
            total = total + weights_hn[m] * arc

        return total

    # Evaluate the scalar-frequency model over the whole frequency array
    Z_model = jax.vmap(impedance_at)(freq_data)

    # Independent Gaussian likelihoods for the real and imaginary parts
    numpyro.sample("obs_re", dist.Normal(jnp.real(Z_model), sigma_re), obs=Z_re_data)
    numpyro.sample("obs_im", dist.Normal(jnp.imag(Z_model), sigma_im), obs=Z_im_data)

1) Generate a synthetic impedance data set (one Warburg element + one HN arc in series with R_inf) and its exact DRT.

In [None]:
# --- Frequency grid: log-spaced between f_min and f_max ---
N_freqs = 81
f_min = 1e-2
f_max = 1e6
freq_vec = np.logspace(np.log10(f_min), np.log10(f_max), num=N_freqs)

# --- Ground-truth parameters of the Warburg element ---
alpha_W  = 0.5
tau0_W   = 1e-2

# --- Ground-truth parameters of the HN arc (plus series resistance) ---
R_inf_HN = 10.0
R_ct_HN  = 20.0
phi_HN   = 0.8
alpha_HN = 0.8
tau0_HN  = 1e-3

# --- Exact impedance: Warburg term + (R_inf + HN arc) ---
Z_w = fct_Z_Warburg_exact(freq_vec, alpha_w=alpha_W, tau0_w=tau0_W)
Z_hn = fct_Z_HN_exact(
    freq_vec,
    R_inf=R_inf_HN,
    R_ct=R_ct_HN,
    phi=phi_HN,
    alpha=alpha_HN,
    tau0=tau0_HN,
)
Z_exact = Z_w + Z_hn

# --- Noisy "measurement": independent Gaussian noise on real and imaginary parts ---
# The seed is fixed and the real-part noise is drawn first, so the data are reproducible.
sigma_noise = 0.2
np.random.seed(1234)
noise_r = np.random.normal(loc=0, scale=sigma_noise, size=N_freqs)
noise_i = np.random.normal(loc=0, scale=sigma_noise, size=N_freqs)
Z_exp = Z_exact + noise_r + 1j*noise_i

# --- Relaxation-time grid for the DRT ---
N_tau = 901
# Alternative grid spanning exactly 1/f_max .. 1/f_min:
# tau_vec = np.logspace(-np.log10(f_max), -np.log10(f_min), num=N_tau)
tau_vec = np.logspace(start=-6.5, stop=2.5, num=N_tau)

# --- Exact DRT: Warburg contribution + HN contribution ---
gamma_w = fct_gamma_Warburg_exact(tau_vec, log_tau0=np.log(tau0_W), alpha=alpha_W)
gamma_hn = fct_gamma_HN_exact(
    tau_vec,
    R_ct=R_ct_HN,
    phi=phi_HN,
    alpha=alpha_HN,
    tau0=tau0_HN,
)
gamma_exact = gamma_w + gamma_hn

2) Fit the combined truncated DP Warburg + HN model (K_W = K_HN = 10 components) with NUTS

In [None]:
# 2) Posterior inference for the combined DP-Warburg-HN model

# Everything the model receives: the synthetic data plus the truncation levels.
run_kwargs = dict(
    freq_data=jnp.array(freq_vec),
    Z_re_data=jnp.array(Z_exp.real),
    Z_im_data=jnp.array(Z_exp.imag),
    K_W=10,     # truncation level of the Warburg family
    K_HN=10,    # truncation level of the HN family
)

# NUTS sampler: 3 chains, each with 800 warm-up and 1500 retained draws.
nuts_kernel = NUTS(dp_warburg_hn_model)
mcmc = MCMC(nuts_kernel, num_warmup=800, num_samples=1500, num_chains=3)
mcmc.run(jax.random.PRNGKey(0), **run_kwargs)

# Posterior draws as a dict of arrays, and as an ArviZ object for diagnostics
posterior_samples = mcmc.get_samples()
az_data = az.from_numpyro(mcmc)

3) Print the posterior summary and save all data into 'dp_warburg_hn_results' (instead of 'dp_zarc_results'); the cell below is currently commented out

In [None]:
# This whole cell is disabled (every line is commented out). Uncommented, it
# would (a) print an ArviZ summary of all sampled sites and (b) write the
# posterior draws and the synthetic data/DRT arrays to dp_warburg_hn_results/.
# Note: posterior_samples is saved as a dict, which np.save stores as a pickled
# object array, so reading it back needs np.load(..., allow_pickle=True).

# print("\n===== Posterior Summary =====")
# print(az.summary(
#     az_data,
#     var_names=[
#         "R_inf","p_W_fraction",
#         "alpha_w","beta_w","log_tau0_w",
#         "alpha_hn","beta_hn","w_hn","phi_hn","alpha_hn_arcs","log_tau0_hn",
#         "sigma_re","sigma_im"
#     ],
#     round_to=3
# ))

# # 3) Save results
# os.makedirs("dp_warburg_hn_results", exist_ok=True)
# np.save("dp_warburg_hn_results/posterior_samples.npy",
#         {k: np.array(v) for k, v in posterior_samples.items()})
# np.save("dp_warburg_hn_results/freq_vec.npy", freq_vec)
# np.save("dp_warburg_hn_results/Z_exp.npy",  Z_exp)
# np.save("dp_warburg_hn_results/Z_exact.npy", Z_exact)
# np.save("dp_warburg_hn_results/tau_vec.npy", tau_vec)
# np.save("dp_warburg_hn_results/gamma_exact.npy", gamma_exact)

# print("Sampling complete. Results in dp_warburg_hn_results/*.npy")