# Prony vs Fractional Models

Compare PyVisco Prony references to RheoJAX generalized and fractional fits.

In [1]:
# Google Colab compatibility - uncomment if running in Colab
# !pip install -q rheojax
# from google.colab import drive
# drive.mount('/content/drive')


## Setup and Imports
Compare Prony-series reference curves against RheoJAX generalized and fractional fits using the time-domain dataset.

In [2]:
# Configure matplotlib for inline plotting in VS Code/Jupyter
%matplotlib inline

import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from rheojax.core.data import RheoData
from rheojax.core.jax_config import safe_import_jax, verify_float64
from rheojax.models.fractional_maxwell_model import FractionalMaxwellModel
from rheojax.models.generalized_maxwell import GeneralizedMaxwell
from rheojax.pipeline.base import Pipeline
from rheojax.transforms.mastercurve import Mastercurve

jax, jnp = safe_import_jax()
verify_float64()
np.set_printoptions(precision=4, suppress=True)
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 11
warnings.filterwarnings('ignore', category=RuntimeWarning)

def r2_complex(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    ss_res = np.sum(np.abs(y_true - y_pred) ** 2)
    ss_tot = np.sum(np.abs(y_true - np.mean(y_true)) ** 2)
    return float(1 - ss_res / ss_tot)


INFO:2025-12-06 04:15:54,714:jax._src.xla_bridge:808: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)


Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file)


Loading rheojax version 0.4.0


  from . import backend, data, dataio, transform


## Load data and Prony terms

In [3]:
DATA_DIR = Path.cwd().parent / 'data' / 'pyvisco' / 'time_master'
relax_df = pd.read_csv(DATA_DIR / 'time_user_master.csv')
relax_clean = relax_df.iloc[1:].astype(float)
t = relax_clean['t'].to_numpy()
E_t = relax_clean['E_relax'].to_numpy()

prony_path = DATA_DIR / 'prony_terms_KPf_MD.csv'
with open(prony_path) as fh:
    lines = [ln.strip() for ln in fh.readlines() if ln.strip()]
E0_line = next((ln for ln in lines if ln.startswith('# E0')), '# E0 = 1.0 MPa')
E0 = float(E0_line.split('=')[1].split()[0])  # MPa

prony_vals = [ln for ln in lines if not ln.startswith('#')]
prony_df = pd.DataFrame([list(map(float, row.split(','))) for row in prony_vals], columns=['rel_mod', 'rel_time'])

# Build reference relaxation curve from Prony series
prony_moduli = E0 * prony_df['rel_mod'].to_numpy()
prony_taus = prony_df['rel_time'].to_numpy()

E_prony = np.zeros_like(t)
for g, tau in zip(prony_moduli, prony_taus):
    E_prony += g * np.exp(-t / tau)

print(f"E0 = {E0:.2f} MPa, modes = {len(prony_moduli)}")
prony_df.head()


E0 = 1739.03 MPa, modes = 31


Unnamed: 0,rel_mod,rel_time
0,0.054565,0.01
1,0.019909,0.1
2,0.022591,1.0
3,0.011452,10.0
4,0.011624,100.0


## Fit RheoJAX models

In [4]:
gm = GeneralizedMaxwell(n_modes=6, modulus_type='tensile')
gm.fit(t, E_t, test_mode='relaxation', use_log_residuals=True)
gm_pred = gm.predict(t)
fm = FractionalMaxwellModel()
fm.fit(t, E_t, test_mode='relaxation', use_log_residuals=True)
fm_pred = fm.predict(t, test_mode='relaxation')

metrics = {
    'prony_ref_r2': r2_complex(E_t, E_prony),
    'gm_r2': gm.score(t, E_t),
    'fm_r2': r2_complex(E_t, fm_pred),
}
metrics


Auto-enabling multi-start optimization for very wide range (30.7 decades, 5 starts)


Starting least squares optimization | {'method': 'trf', 'n_params': 13, 'loss': 'linear', 'ftol': 1e-06, 'xtol': 1e-06, 'gtol': 1e-06}


Timer: optimization took 1.319089s


Convergence: reason=`xtol` termination condition is satisfied. | iterations=13 | final_cost=4.790065e+07 | time=1.319s | final_gradient_norm=1566957452.1387854


Starting least squares optimization | {'method': 'trf', 'n_params': 13, 'loss': 'linear', 'ftol': 1e-06, 'xtol': 1e-06, 'gtol': 1e-06}


Timer: optimization took 0.366344s


Convergence: reason=`xtol` termination condition is satisfied. | iterations=13 | final_cost=4.790065e+07 | time=0.366s | final_gradient_norm=1566957452.1387854


Starting least squares optimization | {'method': 'trf', 'n_params': 11, 'loss': 'linear', 'ftol': 1e-06, 'xtol': 1e-06, 'gtol': 1e-06}


Timer: optimization took 0.956776s


Convergence: reason=`ftol` termination condition is satisfied. | iterations=13 | final_cost=4.726767e+07 | time=0.957s | final_gradient_norm=4301842872.211341


Starting least squares optimization | {'method': 'trf', 'n_params': 9, 'loss': 'linear', 'ftol': 1e-06, 'xtol': 1e-06, 'gtol': 1e-06}


Timer: optimization took 0.964045s


Convergence: reason=`ftol` termination condition is satisfied. | iterations=16 | final_cost=4.726774e+07 | time=0.964s | final_gradient_norm=260843471.5355467


Starting least squares optimization | {'method': 'trf', 'n_params': 7, 'loss': 'linear', 'ftol': 1e-06, 'xtol': 1e-06, 'gtol': 1e-06}


Timer: optimization took 0.936230s


Convergence: reason=`ftol` termination condition is satisfied. | iterations=13 | final_cost=4.727053e+07 | time=0.936s | final_gradient_norm=2379781497.1322446


Starting least squares optimization | {'method': 'trf', 'n_params': 5, 'loss': 'linear', 'ftol': 1e-06, 'xtol': 1e-06, 'gtol': 1e-06}


Timer: optimization took 0.872018s


Convergence: reason=`ftol` termination condition is satisfied. | iterations=15 | final_cost=4.727792e+07 | time=0.872s | final_gradient_norm=50171392.943035275


Starting least squares optimization | {'method': 'trf', 'n_params': 3, 'loss': 'linear', 'ftol': 1e-06, 'xtol': 1e-06, 'gtol': 1e-06}


Timer: optimization took 0.518260s


Convergence: reason=Both `ftol` and `xtol` termination conditions are satisfied. | iterations=2 | final_cost=4.733564e+07 | time=0.518s | final_gradient_norm=0.08726020354349318


Element minimization: reducing from 6 to 1 modes


Auto-enabling multi-start optimization for very wide range (30.7 decades, 5 starts)


Starting least squares optimization | {'method': 'trf', 'n_params': 4, 'loss': 'linear', 'ftol': 1e-06, 'xtol': 1e-06, 'gtol': 1e-06}




Inner optimization loop hit iteration limit | {'inner_iterations': 100, 'actual_reduction': -1}


Timer: optimization took 1.332011s


Convergence: reason=Inner optimization loop exceeded maximum iterations. | iterations=1 | final_cost=6.616982e+05 | time=1.332s | final_gradient_norm=nan


NLSQ hit inner iteration limit; retrying with SciPy least_squares for stability.


{'prony_ref_r2': 0.9812662052520316,
 'gm_r2': 0.4343575559174251,
 'fm_r2': -22784.252813352563}

## Overlay reference vs RheoJAX fits

In [5]:
fig, ax = plt.subplots(figsize=(9, 6))
ax.loglog(t, E_t, 'o', label='Data', alpha=0.5)
ax.loglog(t, E_prony, '-', label='Prony reference')
ax.loglog(t, gm_pred, '--', label='Generalized Maxwell fit')
ax.loglog(t, fm_pred, ':', label='Fractional Maxwell fit')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Relaxation modulus (MPa)')
ax.grid(True, which='both', ls='--', alpha=0.4)
ax.legend()
plt.show()


  plt.show()


## Residual summary

In [6]:
def mpe(y_true, y_pred):
    return float(np.mean(np.abs(y_true - y_pred) / np.maximum(np.abs(y_true), 1e-12)) * 100)

summary = pd.DataFrame([
    {'model': 'Prony reference', 'R2': metrics['prony_ref_r2'], 'MPE_%': mpe(E_t, E_prony)},
    {'model': 'Generalized Maxwell', 'R2': metrics['gm_r2'], 'MPE_%': mpe(E_t, gm_pred)},
    {'model': 'Fractional Maxwell', 'R2': metrics['fm_r2'], 'MPE_%': mpe(E_t, fm_pred)},
])
summary


Unnamed: 0,model,R2,MPE_%
0,Prony reference,0.981266,22.878967
1,Generalized Maxwell,0.434358,139.448414
2,Fractional Maxwell,-22784.252813,738.285177
