# DMTA Model Selection: Which Model for Your Material?

This notebook demonstrates how to compare multiple RheoJAX models on DMTA (dynamic mechanical thermal analysis) data and select the best one for your application.

## Learning Objectives
- Query DMTA-compatible models from the registry
- Fit multiple models to the same DMTA dataset
- Compare fit quality using residual analysis
- Understand when to use each model family

**Estimated Time:** 10 minutes

In [1]:
import gc
import os
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np

if os.path.abspath(os.path.join(os.getcwd(), '../..')) not in sys.path:
    sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../..')))

from rheojax.core.jax_config import safe_import_jax

jax, jnp = safe_import_jax()

np.random.seed(42)
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 11
warnings.filterwarnings('ignore', category=RuntimeWarning)

FAST_MODE = os.environ.get('FAST_MODE', '1') == '1'
print(f'FAST_MODE: {FAST_MODE}')

FAST_MODE: True


## 1. Querying DMTA-Compatible Models

The `ModelRegistry` can filter models by protocol and deformation mode. All models with `Protocol.OSCILLATION` and `DeformationMode.TENSION` support DMTA data.

In [2]:
from rheojax.core.inventory import Protocol
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode

# All DMTA-compatible models
dmta_models = ModelRegistry.find(
    protocol=Protocol.OSCILLATION,
    deformation_mode=DeformationMode.TENSION,
)
print(f'{len(dmta_models)} models support DMTA (oscillation + tension)')

# Group by family (entries without a family fall back to 'other')
families = {}
for name in sorted(dmta_models):
    info = ModelRegistry.get_info(name)
    family = info.family if hasattr(info, 'family') and info.family else 'other'
    families.setdefault(family, []).append(name)

for family, models in sorted(families.items()):
    print(f'\n  {family} ({len(models)}):')
    for m in models:
        print(f'    - {m}')

49 models support DMTA (oscillation + tension)

  other (49):
    - dmt_local
    - fikh
    - fluidity_local
    - fluidity_nonlocal
    - fluidity_saramito_local
    - fmlikh
    - fractional_burgers
    - fractional_jeffreys
    - fractional_kelvin_voigt
    - fractional_kv_zener
    - fractional_maxwell_gel
    - fractional_maxwell_liquid
    - fractional_maxwell_model
    - fractional_poynting_thomson
    - fractional_zener_ll
    - fractional_zener_sl
    - fractional_zener_ss
    - generalized_maxwell
    - giesekus
    - giesekus_multi
    - giesekus_multimode
    - giesekus_single
    - hebraud_lequeux
    - hvm
    - hvm_local
    - hvnm
    - hvnm_local
    - itt_mct_isotropic
    - itt_mct_schematic
    - lattice_epm
    - maxwell
    - mikh
    - ml_ikh
    - sgr_conventional
    - sgr_generic
    - springpot
    - stz_conventional
    - tensorial_epm
    - tnt
    - tnt_cates
    - tnt_loop_bridge
    - tnt_multi_species
    - tnt_single_mode
    - tnt_sticky_rouse
    - 
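In the output above, every model lands in `other`. This suggests that the registry entries do not populate `family` in this version. As a rough substitute, you can group names by their leading token (`fractional_*`, `giesekus_*`, `tnt_*`). This is a naming heuristic, not a RheoJAX feature. The hard-coded list below is a subset of the names printed above, and `group_by_prefix` is a hypothetical helper.

```python
from collections import defaultdict

# Subset of the registry names printed above (hard-coded for illustration)
demo_names = [
    "fractional_burgers", "fractional_zener_ss", "fractional_maxwell_gel",
    "giesekus", "giesekus_multimode", "maxwell", "generalized_maxwell",
    "tnt_cates", "tnt_sticky_rouse", "sgr_generic", "springpot",
]


def group_by_prefix(model_names):
    """Hypothetical helper: group names by the token before the first underscore."""
    groups = defaultdict(list)
    for name in sorted(model_names):
        prefix = name.split("_", 1)[0]
        groups[prefix].append(name)
    return dict(groups)


for prefix, members in sorted(group_by_prefix(demo_names).items()):
    print(f"{prefix} ({len(members)}): {', '.join(members)}")
```

The heuristic is crude. For example, `generalized_maxwell` ends up under `generalized` rather than next to `maxwell`.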

## 2. Synthetic DMTA Dataset

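We generate synthetic DMTA data for a moderately broad glass transition. The transition is broad enough to separate single-relaxation-time models from broadened ones. The shear modulus follows a Cole-Davidson form, and we convert it to the tensile modulus with E* = 2(1 + ν)G*, assuming a frequency-independent Poisson's ratio.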

In [3]:
# Generate Zener-like E*(ω) with moderate broadening
omega = np.logspace(-2, 3, 60 if FAST_MODE else 100)

# Material: amorphous polymer near Tg
G_e = 1e5    # Rubbery modulus (Pa)
G_g = 5e8    # Glassy modulus (Pa)
tau = 0.05   # Relaxation time (s)
beta = 0.6   # Cole-Davidson broadening
nu = 0.40    # Poisson's ratio (assumed frequency-independent)
factor = 2 * (1 + nu)  # = 2.8

# Cole-Davidson distribution
iw_tau = 1j * omega * tau
G_star = G_e + (G_g - G_e) * (1 - 1 / (1 + iw_tau)**beta)
E_star_true = factor * G_star

# Add ~1% multiplicative complex noise (about 1% in |E*| and 0.01 rad in phase)
noise = 1 + 0.01 * np.random.randn(len(omega)) + 1j * 0.01 * np.random.randn(len(omega))
E_star = E_star_true * noise

print(f'E_rubbery = {G_e * factor:.2e} Pa')
print(f'E_glassy  = {G_g * factor:.2e} Pa')
print(f'Frequency range: {omega[0]:.0e} - {omega[-1]:.0e} rad/s')

E_rubbery = 2.80e+05 Pa
E_glassy  = 1.40e+09 Pa
Frequency range: 1e-02 - 1e+03 rad/s
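Here is a quick sanity check on the synthetic spectrum. The Cole-Davidson-type expression in the cell above should approach E_rubbery as ω → 0 and E_glassy as ω → ∞. Writing 1 + iωτ = e^{iφ}/cos φ with φ = arctan(ωτ) gives E'' ∝ (cos φ)^β sin(βφ), which peaks at φ = π/(2(1 + β)). The standalone sketch below re-evaluates the same noise-free formula with the same parameter values and checks both limits and the peak location. It assumes, as the cell does, a constant ν.

```python
import numpy as np

# Same parameters as the synthetic-data cell (noise-free)
G_e, G_g, tau, beta, nu = 1e5, 5e8, 0.05, 0.6, 0.40
factor = 2 * (1 + nu)  # E* = 2(1 + nu) G*, nu assumed frequency-independent


def cole_davidson_E(omega):
    """Noise-free E*(omega) from the Cole-Davidson-type shear modulus."""
    iw_tau = 1j * omega * tau
    G_star = G_e + (G_g - G_e) * (1 - 1 / (1 + iw_tau) ** beta)
    return factor * G_star


# Plateaus: E* -> 2(1 + nu) G_e at low frequency, 2(1 + nu) G_g at high frequency
E_low = cole_davidson_E(np.array([1e-8]))[0]
E_high = cole_davidson_E(np.array([1e12]))[0]

# Loss peak: E'' is proportional to cos(phi)**beta * sin(beta * phi), phi = arctan(omega * tau),
# which is maximal at phi = pi / (2 * (1 + beta))
omega_peak_analytic = np.tan(np.pi / (2 * (1 + beta))) / tau
omega_dense = np.logspace(-2, 4, 20001)
omega_peak_numeric = omega_dense[np.argmax(np.imag(cole_davidson_E(omega_dense)))]

print(f"E'(low)  = {E_low.real:.3e} Pa")
print(f"E'(high) = {E_high.real:.3e} Pa")
print(f"E'' peak: analytic {omega_peak_analytic:.2f} rad/s, numeric {omega_peak_numeric:.2f} rad/s")
```

With β = 0.6, the loss peak sits somewhat above 1/τ. Reading 1/(loss-peak frequency) as τ would therefore underestimate the relaxation time of this broadened spectrum.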


## 3. Multi-Model Comparison

We fit three models of increasing complexity to the same DMTA data:
1. **Maxwell** (2 params) — single relaxation time, no equilibrium modulus
2. **Zener** (3 params) — single relaxation with rubbery plateau
3. **Fractional Zener SS** (4 params) — broad relaxation with plateau
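For reference, the textbook shear-modulus forms of the three models are given below. Each converts to tension through E*(ω) = 2(1 + ν)G*(ω) for a frequency-independent ν. RheoJAX may parameterize these models differently (for example, through a viscosity instead of τ), so read them as the standard forms rather than the library's exact internals.

$$
\begin{aligned}
G^*_{\text{Maxwell}}(\omega) &= G\,\frac{i\omega\tau}{1 + i\omega\tau} \\
G^*_{\text{Zener}}(\omega) &= G_e + G_m\,\frac{i\omega\tau}{1 + i\omega\tau} \\
G^*_{\text{FZSS}}(\omega) &= G_e + G_m\,\frac{(i\omega\tau)^{\alpha}}{1 + (i\omega\tau)^{\alpha}}, \qquad 0 < \alpha \le 1 \\
E^*(\omega) &= 2(1+\nu)\,G^*(\omega)
\end{aligned}
$$

Maxwell tends to zero as ω → 0, which is why it has no rubbery plateau. Setting α = 1 in the FZSS form recovers the Zener form exactly.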

In [4]:
from rheojax.models.classical.maxwell import Maxwell
from rheojax.models.classical.zener import Zener
from rheojax.models.fractional.fractional_zener_ss import FractionalZenerSolidSolid

results = {}

# Model 1: Maxwell
m1 = Maxwell()
m1.fit(omega, E_star, test_mode='oscillation', deformation_mode='tension', poisson_ratio=nu)
E1 = m1.predict(omega, test_mode='oscillation')
results['Maxwell (2p)'] = {'model': m1, 'pred': E1, 'n_params': 2, 'color': 'blue'}
del m1; gc.collect()

# Model 2: Zener
m2 = Zener()
m2.fit(omega, E_star, test_mode='oscillation', deformation_mode='tension', poisson_ratio=nu)
E2 = m2.predict(omega, test_mode='oscillation')
results['Zener (3p)'] = {'model': m2, 'pred': E2, 'n_params': 3, 'color': 'green'}
del m2; gc.collect()

# Model 3: Fractional Zener SS
m3 = FractionalZenerSolidSolid()
m3.fit(omega, E_star, test_mode='oscillation', deformation_mode='tension', poisson_ratio=nu)
E3 = m3.predict(omega, test_mode='oscillation')
alpha = m3.parameters.get_value('alpha')
results[f'FZSS (4p, α={alpha:.2f})'] = {'model': m3, 'pred': E3, 'n_params': 4, 'color': 'red'}
del m3; gc.collect()

print('Fitting complete for all 3 models')

INFO:nlsq.least_squares:Starting least squares optimization method=trf | n_params=2 | loss=linear | ftol=1.0000e-06 | xtol=1.0000e-06 | gtol=1.0000e-06


PERFORMANCE:nlsq.least_squares:Timer: optimization elapsed=0.815927s


INFO:nlsq.least_squares:Convergence reason=`ftol` termination condition is satisfied. | iterations=9 | final_cost=47.5979 | elapsed=0.816s | final_gradient_norm=0.9406


INFO:nlsq.least_squares:Starting least squares optimization method=trf | n_params=3 | loss=linear | ftol=1.0000e-06 | xtol=1.0000e-06 | gtol=1.0000e-06


PERFORMANCE:nlsq.least_squares:Timer: optimization elapsed=0.722469s


INFO:nlsq.least_squares:Convergence reason=`ftol` termination condition is satisfied. | iterations=21 | final_cost=1.9819 | elapsed=0.722s | final_gradient_norm=0.0020


INFO:nlsq.least_squares:Starting least squares optimization method=trf | n_params=4 | loss=linear | ftol=1.0000e-06 | xtol=1.0000e-06 | gtol=1.0000e-06


PERFORMANCE:nlsq.least_squares:Timer: optimization elapsed=1.146388s


INFO:nlsq.least_squares:Convergence reason=`ftol` termination condition is satisfied. | iterations=20 | final_cost=1.8925 | elapsed=1.146s | final_gradient_norm=0.0014


Fitting complete for all 3 models


In [5]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# E' comparison
ax = axes[0, 0]
ax.loglog(omega, np.real(E_star), 'ko', ms=3, alpha=0.4, label='Data')
for name, r in results.items():
    ax.loglog(omega, np.real(r['pred']), '-', color=r['color'], lw=2, label=name)
ax.set_xlabel('ω (rad/s)')
ax.set_ylabel("E' (Pa)")
ax.set_title("Storage Modulus E'(ω)")
ax.legend(fontsize=8)

# E'' comparison
ax = axes[0, 1]
ax.loglog(omega, np.imag(E_star), 'ko', ms=3, alpha=0.4, label='Data')
for name, r in results.items():
    ax.loglog(omega, np.imag(r['pred']), '-', color=r['color'], lw=2, label=name)
ax.set_xlabel('ω (rad/s)')
ax.set_ylabel('E" (Pa)')
ax.set_title('Loss Modulus E"(ω)')
ax.legend(fontsize=8)

# tan(δ) comparison
ax = axes[1, 0]
tan_d_data = np.imag(E_star) / np.real(E_star)
ax.semilogx(omega, tan_d_data, 'ko', ms=3, alpha=0.4, label='Data')
for name, r in results.items():
    tan_d = np.imag(r['pred']) / np.real(r['pred'])
    ax.semilogx(omega, tan_d, '-', color=r['color'], lw=2, label=name)
ax.set_xlabel('ω (rad/s)')
ax.set_ylabel('tan(δ)')
ax.set_title('Loss Tangent')
ax.legend(fontsize=8)

# Residual comparison
ax = axes[1, 1]
for name, r in results.items():
    residual_pct = 100 * (np.abs(r['pred']) - np.abs(E_star)) / np.abs(E_star)
    ax.semilogx(omega, residual_pct, '-', color=r['color'], lw=1.5, label=name)
ax.axhline(0, color='gray', ls='--', alpha=0.5)
ax.set_xlabel('ω (rad/s)')
ax.set_ylabel('Residual (%)')
ax.set_title('Relative Residual')
ax.legend(fontsize=8)
ax.set_ylim(-30, 30)

plt.tight_layout()
plt.close('all')
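In this notebook, the conversion E* = 2(1 + ν)G* multiplies by a real constant. That constant cancels in tan δ = E''/E', so the loss-tangent panel is the same in shear and tension and does not depend on the assumed ν. The sketch below checks this on the noise-free synthetic spectrum. The cancellation would not hold if ν were treated as complex or frequency-dependent.

```python
import numpy as np

omega_demo = np.logspace(-2, 3, 60)
iw_tau_demo = 1j * omega_demo * 0.05
# Same Cole-Davidson-type shear modulus as the synthetic-data cell (noise-free)
G_demo = 1e5 + (5e8 - 1e5) * (1 - 1 / (1 + iw_tau_demo) ** 0.6)


def tan_delta(M_star):
    """Loss tangent M''/M' of a complex modulus M*."""
    return np.imag(M_star) / np.real(M_star)


tan_shear = tan_delta(G_demo)
# E* = 2(1 + nu) G* with a real, frequency-independent nu
tan_tension = {nu_demo: tan_delta(2 * (1 + nu_demo) * G_demo) for nu_demo in (0.30, 0.40, 0.49)}
for nu_demo, td in tan_tension.items():
    print(f"nu = {nu_demo:.2f}: max |tan_delta(E*) - tan_delta(G*)| = {np.max(np.abs(td - tan_shear)):.1e}")
```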

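## 4. Quantitative Fit Metrics

The table below scores each fit by R² and by the maximum relative residual. Both are computed on |E*| in linear units, so the high-modulus glassy region dominates R².

In [6]: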
print(f'{"Model":<25s}  {"N_params":>8}  {"R²":>10}  {"Max |Resid|%":>12}')
print('-' * 60)

for name, r in results.items():
    residual = np.abs(E_star) - np.abs(r['pred'])
    ss_res = np.sum(residual**2)
    ss_tot = np.sum((np.abs(E_star) - np.mean(np.abs(E_star)))**2)
    R2 = 1 - ss_res / ss_tot
    max_resid = 100 * np.max(np.abs(residual) / np.abs(E_star))
    print(f'{name:<25s}  {r["n_params"]:>8d}  {R2:>10.6f}  {max_resid:>12.1f}%')

Model                      N_params          R²  Max |Resid|%
------------------------------------------------------------
Maxwell (2p)                      2   -0.762640         100.0%
Zener (3p)                        3    0.984141          18.1%
FZSS (4p, α=1.00)                 4   -1.031098      188889.6%
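The FZSS row deserves a second look. α = 1.00 is the classical limit, where the springpot acts as a dashpot and FZSS reduces to the Zener form. The optimizer log also reports a slightly lower final cost for FZSS (1.8925) than for Zener (1.9819). A near-Zener fit should therefore score close to Zener's R², so the 188889.6% residual more likely points to a problem in the prediction step than to a limitation of the model. The standalone sketch below checks the α = 1 limit numerically. It uses the textbook forms, not RheoJAX's implementation, and arbitrary illustrative parameter values.

```python
import numpy as np


def zener(omega, G_e, G_m, tau):
    """Textbook Zener (standard linear solid) shear modulus G*(omega)."""
    x = 1j * omega * tau
    return G_e + G_m * x / (1 + x)


def fzss(omega, G_e, G_m, tau, alpha):
    """Textbook fractional Zener solid-solid G*(omega); alpha = 1 gives Zener."""
    x = (1j * omega * tau) ** alpha
    return G_e + G_m * x / (1 + x)


omega_demo = np.logspace(-2, 3, 60)
params = dict(G_e=1e5, G_m=5e8, tau=0.05)  # arbitrary illustrative values

G_zener = zener(omega_demo, **params)
G_fzss_1 = fzss(omega_demo, **params, alpha=1.0)
G_fzss_06 = fzss(omega_demo, **params, alpha=0.6)

# Largest relative difference in |G*| between FZSS and Zener
rel_diff_1 = np.max(np.abs(np.abs(G_fzss_1) - np.abs(G_zener)) / np.abs(G_zener))
rel_diff_06 = np.max(np.abs(np.abs(G_fzss_06) - np.abs(G_zener)) / np.abs(G_zener))
print(f"max relative |G*| difference, alpha=1.0: {rel_diff_1:.1e}")
print(f"max relative |G*| difference, alpha=0.6: {rel_diff_06:.1e}")
```

If the α = 1 difference is at round-off level, a large residual from an α = 1 FZSS fit should be traced through the fitted parameters and the `predict` call before the model is ranked.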


## 5. Model Selection with Real DMTA Data

We repeat the multi-model comparison on real polymer data from the pyvisco project. This tests whether model rankings hold on actual experimental data with real noise and measurement artifacts.

In [7]:
import pandas as pd

# Load real frequency-domain master curve
data_dir = os.path.join(os.path.dirname(os.path.abspath('.')), 'dmta', 'data')
if not os.path.exists(data_dir):
    data_dir = os.path.join('.', 'data')

# skiprows=[1] skips the line after the header (the units row)
df_master = pd.read_csv(os.path.join(data_dir, 'freq_user_master.csv'), skiprows=[1])
omega_real = 2 * np.pi * df_master['f'].values  # f in Hz -> omega in rad/s
E_star_real = (df_master['E_stor'].values + 1j * df_master['E_loss'].values) * 1e6  # MPa -> Pa

print(f'Real data: {len(omega_real)} pts, {np.log10(omega_real.max()/omega_real.min()):.1f} decades')

Real data: 206 pts, 26.0 decades
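The loading cell relies on a particular file layout:

- a header row with columns `f`, `E_stor` and `E_loss`
- a second line that `skiprows=[1]` discards (presumably a units row)
- frequencies in Hz
- moduli in MPa

The sketch below applies the same parsing to a small, made-up in-memory CSV, so you can check the unit conversions without the data file. The values are illustrative only.

```python
import io

import numpy as np
import pandas as pd

# Made-up three-point file with a units row, mimicking the layout assumed above
csv_text = """f,E_stor,E_loss
Hz,MPa,MPa
0.1,2.0,0.5
1.0,20.0,5.0
10.0,200.0,30.0
"""

df_demo = pd.read_csv(io.StringIO(csv_text), skiprows=[1])  # drop the units row
omega_demo = 2 * np.pi * df_demo["f"].to_numpy()            # Hz -> rad/s
E_demo = (df_demo["E_stor"].to_numpy() + 1j * df_demo["E_loss"].to_numpy()) * 1e6  # MPa -> Pa

decades = np.log10(omega_demo.max() / omega_demo.min())
print(f"{len(omega_demo)} pts, {decades:.1f} decades")
print(E_demo)
```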


In [8]:
# Fit same models to real data
results_real = {}

# Maxwell
m1 = Maxwell()
m1.fit(omega_real, E_star_real, test_mode='oscillation', deformation_mode='tension', poisson_ratio=0.35)
E1r = m1.predict(omega_real, test_mode='oscillation')
results_real['Maxwell (2p)'] = {'pred': E1r, 'n_params': 2, 'color': 'blue'}
del m1; gc.collect()

# Zener
m2 = Zener()
m2.fit(omega_real, E_star_real, test_mode='oscillation', deformation_mode='tension', poisson_ratio=0.35)
E2r = m2.predict(omega_real, test_mode='oscillation')
results_real['Zener (3p)'] = {'pred': E2r, 'n_params': 3, 'color': 'green'}
del m2; gc.collect()

# FZSS
m3 = FractionalZenerSolidSolid()
m3.fit(omega_real, E_star_real, test_mode='oscillation', deformation_mode='tension', poisson_ratio=0.35)
E3r = m3.predict(omega_real, test_mode='oscillation')
alpha_r = m3.parameters.get_value('alpha')
results_real[f'FZSS (4p, alpha={alpha_r:.2f})'] = {'pred': E3r, 'n_params': 4, 'color': 'red'}
del m3; gc.collect()

if not FAST_MODE:
    # Additional models in full mode
    from rheojax.models.fractional.fractional_maxwell_model import (
        FractionalMaxwellModel,
    )
    from rheojax.models.multimode.generalized_maxwell import GeneralizedMaxwell

    m4 = FractionalMaxwellModel()
    m4.fit(omega_real, E_star_real, test_mode='oscillation', deformation_mode='tension', poisson_ratio=0.35)
    E4r = m4.predict(omega_real, test_mode='oscillation')
    results_real['FMM (3p)'] = {'pred': E4r, 'n_params': 3, 'color': 'purple'}
    del m4; gc.collect()

    gmm = GeneralizedMaxwell(n_modes=10, modulus_type='tensile')
    gmm.fit(omega_real, E_star_real, test_mode='oscillation', optimization_factor=1.5)
    E5r = gmm.predict(omega_real, test_mode='oscillation')
    if E5r.ndim == 2 and E5r.shape[1] == 2:
        E5r = E5r[:, 0] + 1j * E5r[:, 1]
    results_real[f'GMM ({gmm._n_modes}m)'] = {'pred': E5r, 'n_params': gmm._n_modes * 2 + 1, 'color': 'orange'}
    del gmm; gc.collect()

print(f'Fitted {len(results_real)} models to real data')

INFO:nlsq.least_squares:Starting least squares optimization method=trf | n_params=2 | loss=linear | ftol=1.0000e-06 | xtol=1.0000e-06 | gtol=1.0000e-06


PERFORMANCE:nlsq.least_squares:Timer: optimization elapsed=0.359678s


INFO:nlsq.least_squares:Convergence reason=`ftol` termination condition is satisfied. | iterations=42 | final_cost=176.3336 | elapsed=0.360s | final_gradient_norm=10.7893


INFO:nlsq.least_squares:Starting least squares optimization method=trf | n_params=3 | loss=linear | ftol=1.0000e-06 | xtol=1.0000e-06 | gtol=1.0000e-06


PERFORMANCE:nlsq.least_squares:Timer: optimization elapsed=0.386055s


INFO:nlsq.least_squares:Convergence reason=`ftol` termination condition is satisfied. | iterations=56 | final_cost=115.5200 | elapsed=0.386s | final_gradient_norm=0.2879


INFO:nlsq.least_squares:Starting least squares optimization method=trf | n_params=4 | loss=linear | ftol=1.0000e-06 | xtol=1.0000e-06 | gtol=1.0000e-06


PERFORMANCE:nlsq.least_squares:Timer: optimization elapsed=0.382746s


INFO:nlsq.least_squares:Convergence reason=`xtol` termination condition is satisfied. | iterations=6 | final_cost=46.7768 | elapsed=0.383s | final_gradient_norm=2.8145e+09


Fitted 3 models to real data
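The FZSS fit on real data stopped on `xtol` after 6 iterations with a final gradient norm of about 2.8e9. That suggests a poorly scaled problem rather than a converged fit. One common remedy for data spanning many decades is to fit log10 E' and log10 E'' and to parameterize the moduli and τ in log10, so every decade carries comparable weight. The sketch below applies this to the textbook FZSS form with `scipy.optimize.least_squares`. It is not how RheoJAX fits internally. The data are synthetic and noise-free, and the starting guess, bounds and ν = 0.35 are illustrative assumptions.

```python
import numpy as np
from scipy.optimize import least_squares


def fzss_E(omega, log_Ge, log_Gm, log_tau, alpha, nu=0.35):
    """Textbook FZSS tensile modulus E*(omega); moduli and tau enter as log10 values."""
    x = (1j * omega * 10.0 ** log_tau) ** alpha
    G_star = 10.0 ** log_Ge + 10.0 ** log_Gm * x / (1 + x)
    return 2 * (1 + nu) * G_star


def log_residuals(p, omega, E_data):
    """Residuals in log10 E' and log10 E'' so every decade carries similar weight."""
    E_model = fzss_E(omega, *p)
    return np.concatenate([
        np.log10(E_model.real) - np.log10(E_data.real),
        np.log10(E_model.imag) - np.log10(E_data.imag),
    ])


# Noise-free synthetic curve over 12 decades (illustrative parameter values)
omega_demo = np.logspace(-6, 6, 120)
p_true = np.array([6.0, 9.0, -1.0, 0.5])  # log10 Ge, log10 Gm, log10 tau, alpha
E_data = fzss_E(omega_demo, *p_true)

fit = least_squares(
    log_residuals,
    x0=[5.5, 8.5, -0.5, 0.6],  # half a decade (and 0.1 in alpha) away from the truth
    bounds=([0.0, 0.0, -8.0, 0.05], [12.0, 12.0, 8.0, 1.0]),
    args=(omega_demo, E_data),
)
print("fitted:", np.round(fit.x, 3), "| cost:", f"{fit.cost:.2e}")
```

For 0 < α ≤ 1, both E' and E'' of this form stay positive, so the logarithms are always defined within the bounds.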


In [9]:
# Results table for real data
print(f'{"Model":<30s}  {"N_params":>8}  {"R2":>10}  {"Max |Resid|%":>12}')
print('-' * 65)

E_stor_real = np.real(E_star_real)
E_loss_real = np.imag(E_star_real)

for name, r in results_real.items():
    pred = r['pred']
    pred_p = np.real(pred)
    pred_pp = np.imag(pred)
    res_p = E_stor_real - pred_p
    res_pp = E_loss_real - pred_pp
    ss_res = np.sum(res_p**2) + np.sum(res_pp**2)
    ss_tot = np.sum((E_stor_real - np.mean(E_stor_real))**2) + np.sum((E_loss_real - np.mean(E_loss_real))**2)
    R2 = 1 - ss_res / ss_tot
    max_resid = 100 * np.max(np.abs(np.abs(E_star_real) - np.abs(pred)) / np.abs(E_star_real))
    print(f'{name:<30s}  {r["n_params"]:>8d}  {R2:>10.6f}  {max_resid:>12.1f}%')

Model                           N_params          R2  Max |Resid|%
-----------------------------------------------------------------
Maxwell (2p)                           2   -0.238374         100.0%
Zener (3p)                             3   -0.157567          96.1%
FZSS (4p, alpha=0.20)                  4  -34760641858105867939267533819920884252453293305932066877165466292469709743583308945985157950019392091123520559902621696.000000  4223773651439806974836230136703205316762174520446264882449350656.0%
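Both results tables score fits in linear units. Because the real master curve spans 26 decades, those sums of squares are dominated by the glassy plateau, and R² also ignores how many parameters each model uses. An alternative is to compute R² on log10 E' and log10 E'', plus a least-squares AIC, n·ln(SS_res/n) + 2k, from the same log residuals. The helper `log_fit_metrics` below is hypothetical and not part of RheoJAX. In this notebook it could be called as `log_fit_metrics(E_star_real, r['pred'], r['n_params'])`. AIC values are only comparable across models fit to the same points, and lower is better.

```python
import numpy as np


def log_fit_metrics(E_data, E_pred, n_params):
    """Hypothetical helper: R^2 and least-squares AIC on log10 E' and log10 E''.

    Points where either curve is non-finite or has a non-positive E' or E''
    are dropped, because their logarithm is undefined. Compare AIC across
    models only when the returned residual count n is the same.
    """
    E_data = np.asarray(E_data, dtype=complex)
    E_pred = np.asarray(E_pred, dtype=complex)
    ok = (
        np.isfinite(E_data) & np.isfinite(E_pred)
        & (E_data.real > 0) & (E_data.imag > 0)
        & (E_pred.real > 0) & (E_pred.imag > 0)
    )
    yp, ypp = np.log10(E_data.real[ok]), np.log10(E_data.imag[ok])
    yp_hat, ypp_hat = np.log10(E_pred.real[ok]), np.log10(E_pred.imag[ok])
    ss_res = np.sum((yp - yp_hat) ** 2) + np.sum((ypp - ypp_hat) ** 2)
    ss_tot = np.sum((yp - yp.mean()) ** 2) + np.sum((ypp - ypp.mean()) ** 2)
    n = 2 * int(ok.sum())
    r2 = 1 - ss_res / ss_tot
    aic = n * np.log(ss_res / n) + 2 * n_params  # Gaussian least-squares AIC, up to a constant
    return r2, aic, n


# Illustrative check on a Zener-type curve with 1% multiplicative noise
rng = np.random.default_rng(0)
omega_demo = np.logspace(-2, 3, 60)
x = 1j * omega_demo * 0.05
E_true = 2.8 * (1e5 + 5e8 * x / (1 + x))
E_noisy = E_true * (1 + 0.01 * rng.standard_normal(omega_demo.size))
E_maxwell_like = 2.8 * 5e8 * x / (1 + x)  # same curve without the rubbery plateau

r2_good, aic_good, n_good = log_fit_metrics(E_noisy, E_true, n_params=3)
r2_bad, aic_bad, n_bad = log_fit_metrics(E_noisy, E_maxwell_like, n_params=2)
print(f"Zener-like:   R2_log = {r2_good:.4f}, AIC = {aic_good:.1f}, n = {n_good}")
print(f"Maxwell-like: R2_log = {r2_bad:.4f}, AIC = {aic_bad:.1f}, n = {n_bad}")
```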


In [10]:
# Residual analysis on real data
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# E' comparison
ax = axes[0]
ax.loglog(omega_real, np.real(E_star_real), 'ko', ms=3, alpha=0.3, label='Data')
for name, r in results_real.items():
    ax.loglog(omega_real, np.real(r['pred']), '-', color=r['color'], lw=1.5, label=name)
ax.set_xlabel('ω (rad/s)')
ax.set_ylabel("E' (Pa)")
ax.set_title("E' on Real Data")
ax.legend(fontsize=7)

# E'' comparison
ax = axes[1]
ax.loglog(omega_real, np.imag(E_star_real), 'ko', ms=3, alpha=0.3, label='Data')
for name, r in results_real.items():
    ax.loglog(omega_real, np.imag(r['pred']), '-', color=r['color'], lw=1.5, label=name)
ax.set_xlabel('ω (rad/s)')
ax.set_ylabel('E" (Pa)')
ax.set_title('E" on Real Data')
ax.legend(fontsize=7)

# Residuals
ax = axes[2]
for name, r in results_real.items():
    resid = 100 * (np.real(r['pred']) - np.real(E_star_real)) / np.real(E_star_real)
    ax.semilogx(omega_real, resid, '-', color=r['color'], lw=1, alpha=0.7, label=name)
ax.axhline(0, color='gray', ls='--', alpha=0.5)
ax.set_xlabel('ω (rad/s)')
ax.set_ylabel("E' Residual (%)")
ax.set_title('Residual Analysis')
ax.legend(fontsize=7)

plt.tight_layout()
plt.close('all')

## 6. Model Selection Decision Guide

| Your Data Shows | Recommended Model | Why |
|----------------|-------------------|-----|
| Narrow E'' peak, single tau | Zener | Minimal parameters, physical |
| Broad E'' peak (glass transition) | Fractional Zener SS | alpha captures breadth compactly |
| Multi-decade master curve | Generalized Maxwell | Prony series, FEM export |
| T-dependent measurements | VLBVariant or HVM | Built-in Arrhenius kinetics |
| Vitrimer with Tv transition | HVMLocal | TST bond exchange kinetics |
| Quick baseline | Maxwell | 2 parameters, no plateau |

**Rule of thumb**: Start with Fractional Zener SS for amorphous polymers near Tg. Use Generalized Maxwell for master curves destined for FEM software (ANSYS, Abaqus).
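The Generalized Maxwell row refers to a Prony series, E*(ω) = E∞ + Σ E_i·iωτ_i/(1 + iωτ_i). The sketch below evaluates such a series. It also converts the series to relative moduli g_i = E_i/E_0, with instantaneous modulus E_0 = E∞ + Σ E_i, which FEM codes such as Abaqus commonly expect for time-domain Prony input. The mode values are made up, and the function names are illustrative rather than RheoJAX API. Check your FEM package's documentation for its exact input format.

```python
import numpy as np


def prony_E_star(omega, E_inf, E_i, tau_i):
    """Complex modulus of a Prony series (generalized Maxwell model)."""
    omega = np.asarray(omega, dtype=float)[:, None]  # shape (n_freq, 1)
    x = 1j * omega * np.asarray(tau_i)[None, :]      # shape (n_freq, n_modes)
    return E_inf + np.sum(np.asarray(E_i) * x / (1 + x), axis=1)


def relative_prony_terms(E_inf, E_i):
    """Instantaneous modulus E_0 and relative moduli g_i = E_i / E_0."""
    E_0 = E_inf + np.sum(E_i)
    return E_0, np.asarray(E_i) / E_0


# Made-up three-mode series (moduli in Pa, relaxation times in s)
E_inf = 2.8e5
E_i = np.array([4.0e8, 6.0e8, 3.9e8])
tau_i = np.array([1e-3, 1e-2, 1e-1])

omega_demo = np.logspace(-4, 6, 11)
E_star_demo = prony_E_star(omega_demo, E_inf, E_i, tau_i)
E_0, g_i = relative_prony_terms(E_inf, E_i)
print(f"E_0 = {E_0:.3e} Pa, g_i = {np.round(g_i, 4)}, sum(g_i) = {g_i.sum():.4f}")
```

The relative moduli sum to 1 − E∞/E_0. E'(ω) rises from E∞ at low frequency to E_0 at high frequency.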

## Key Takeaways

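- **49 models** in RheoJAX support DMTA via `deformation_mode='tension'` (count from the registry query above)
- **The real data do not yet confirm a winner**: in FAST_MODE, every fitted model scores R² < 0 on the 26-decade master curve. The Generalized Maxwell fit, which the decision guide recommends for such curves, only runs with `FAST_MODE=0`
- **Classical Zener** provides a good baseline on the synthetic data (R² ≈ 0.98 on |E*|) but misses the breadth of the real transition
- **Check convergence before ranking**: both FZSS fits here are suspect. α ended at 1.00 on the synthetic data, and the real-data fit stopped on `xtol` with a gradient norm of about 2.8e9, so their scores likely reflect a fitting or prediction problem rather than the model's capability
- **Maxwell** lacks a rubbery plateau and is unsuitable for solid-state DMTA
- **Model ranking may shift** between synthetic and real data, because real master curves span more decades and carry measurement artifacts that clean synthetic curves lack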

## Next Steps

- `07_dmta_tts_pipeline.ipynb`: Build master curve from raw multi-temperature data
- `08_dmta_cross_domain.ipynb`: Cross-domain consistency analysis

In [11]:
del results
gc.collect()
jax.clear_caches()
print('Cleanup complete')

Cleanup complete
