# DMTA TTS Pipeline: Raw Multi-Temperature to Master Curve

End-to-end workflow: load raw multi-temperature DMTA sweeps, apply TTS with known or auto-detected shift factors, fit a Prony series, and extract WLF parameters.

## Learning Objectives
- Build a master curve from raw multi-temperature frequency sweeps
- Compare known shift factors vs auto-detected shifts
- Fit Generalized Maxwell (Prony series) to the master curve
- Extract WLF parameters (C1, C2) at a fixed reference temperature T_ref from shift factors
- Validate against pre-built master curve

**Data**: pyvisco project (NREL, MIT License)

**Estimated Time:** 5-8 minutes (FAST_MODE), 10-15 minutes (full)

In [None]:
import gc
import os
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

if os.path.abspath(os.path.join(os.getcwd(), '../..')) not in sys.path:
    sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../..')))

from rheojax.core.jax_config import safe_import_jax

jax, jnp = safe_import_jax()

# Seed NumPy's global RNG so any randomised step repeats between runs.
np.random.seed(42)

# Notebook-wide matplotlib defaults.
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'font.size': 11,
})

# Suppress RuntimeWarning messages for the rest of the session.
warnings.filterwarnings('ignore', category=RuntimeWarning)

# FAST_MODE is enabled unless the FAST_MODE environment variable is set to
# something other than '1'.
FAST_MODE = os.environ.get('FAST_MODE', '1') == '1'
print(f'FAST_MODE: {FAST_MODE}')

## 1. Load Raw Multi-Temperature Frequency Sweeps

The pyvisco dataset contains 21 temperature sweeps from -50 to 100 degrees C, each with 10 frequency points (0.1 - 100 Hz). This is typical DMTA output before any post-processing.

In [None]:
# Resolve the data directory: try <parent of cwd>/dmta/data first and fall
# back to ./data when that path does not exist.
data_dir = os.path.join(os.path.dirname(os.path.abspath('.')), 'dmta', 'data')
if not os.path.exists(data_dir):
    data_dir = os.path.join('.', 'data')

# File line index 1 is skipped on read (presumably a units row — TODO confirm
# against the CSV); header names are stripped of surrounding whitespace.
raw_csv = os.path.join(data_dir, 'freq_user_raw.csv')
df_raw = pd.read_csv(raw_csv, skiprows=[1])
df_raw.columns = df_raw.columns.str.strip()

temperatures = sorted(df_raw['T'].unique())
n_temps = len(temperatures)
print(f'Raw data: {len(df_raw)} points across {n_temps} temperatures')
print(f'Temperature range: {min(temperatures):.0f} to {max(temperatures):.0f} degrees C')
print(f'Frequency range: {df_raw["f"].min():.1f} to {df_raw["f"].max():.0f} Hz')

# Two side-by-side panels: storage modulus on the left, loss modulus on the
# right, one colour per temperature.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
colors = plt.cm.coolwarm(np.linspace(0, 1, n_temps))
panels = ((ax1, 'E_stor'), (ax2, 'E_loss'))

for idx, T in enumerate(temperatures):
    # Rows within 0.5 of this temperature belong to the same sweep.
    sweep = df_raw[np.abs(df_raw['T'] - T) < 0.5]
    # Only every fourth temperature gets a legend entry.
    label = f'{T:.0f}' if idx % 4 == 0 else None
    for ax, column in panels:
        ax.loglog(sweep['f'], sweep[column], 'o', color=colors[idx], ms=4, label=label)

ax1.set_ylabel("E' (MPa)")
ax1.set_title(f"Raw DMTA: E'(f) at {n_temps} temperatures")
ax2.set_ylabel('E" (MPa)')
ax2.set_title('Raw DMTA: E"(f)')
for ax in (ax1, ax2):
    ax.set_xlabel('Frequency (Hz)')
    ax.legend(fontsize=7, title='T (C)', ncol=2)

plt.tight_layout()
plt.close('all')

## 2. Load Known Shift Factors

The pyvisco project provides pre-computed TTS shift factors. We use these as the ground truth for building the master curve below; they can also serve as the reference when validating auto-shift detection.

In [None]:
# Read the reference shift-factor table. File line index 1 is skipped
# (presumably a units row — TODO confirm against the CSV) and header names
# are stripped of surrounding whitespace.
shift_csv = os.path.join(data_dir, 'freq_user_master__shift_factors.csv')
df_shifts = pd.read_csv(shift_csv, skiprows=[1])
df_shifts.columns = df_shifts.columns.str.strip()

# Temperatures and the matching log(aT) values as plain arrays.
T_shift = df_shifts['T'].values
log_aT = df_shifts['log_aT'].values

print('\n'.join([
    f'Shift factors: {len(T_shift)} temperatures',
    f'T range: {T_shift.min():.0f} to {T_shift.max():.0f} degrees C',
    f'log(aT) range: {log_aT.min():.2f} to {log_aT.max():.2f}',
]))

# Shift factor versus temperature.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(T_shift, log_aT, 'ko-', ms=6)
ax.set(
    xlabel='Temperature (C)',
    ylabel='log(aT)',
    title='TTS Shift Factors (pyvisco reference)',
)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.close('all')

## 3. Build Master Curve with Known Shift Factors

We shift each temperature sweep by the known shift factor and merge into a single master curve.

In [None]:
from rheojax.core.data import RheoData

# Reference temperature = tabulated temperature whose log(aT) is closest to 0
T_ref_C = T_shift[np.argmin(np.abs(log_aT))]
T_ref_K = T_ref_C + 273.15
print(f'Reference temperature: {T_ref_C:.0f} C ({T_ref_K:.1f} K)')

# Create one RheoData object per temperature sweep
datasets = []
for T in temperatures:
    # Rows within 0.5 of this temperature belong to the same sweep
    mask = np.abs(df_raw['T'] - T) < 0.5
    sub = df_raw[mask]
    omega = 2 * np.pi * sub['f'].values  # Hz -> rad/s
    E_star = (sub['E_stor'].values + 1j * sub['E_loss'].values) * 1e6  # MPa -> Pa

    data = RheoData(
        x=omega, y=E_star,
        metadata={'temperature': T + 273.15, 'deformation_mode': 'tension'},
        validate=False,
    )
    datasets.append(data)

# Shift factor dictionary: T (K) -> aT, with aT = 10**log(aT)
shift_dict = {T_c + 273.15: 10.0**lag for T_c, lag in zip(T_shift, log_aT)}

# Sweeps are matched to the *nearest* tabulated temperature rather than by
# exact float equality: the raw-data and shift-table temperatures come from
# different files, and an exact dictionary lookup would silently leave a sweep
# unshifted on any tiny mismatch. The tolerance equals the one used above to
# group raw rows by temperature.
SHIFT_MATCH_TOL = 0.5
shift_T_K = np.array(list(shift_dict.keys()), dtype=float)
shift_aT = np.array(list(shift_dict.values()), dtype=float)

# Apply the horizontal shifts and collect every sweep
omega_parts = []
E_parts = []
for data in datasets:
    T_K = data.metadata['temperature']
    nearest = int(np.argmin(np.abs(shift_T_K - T_K)))
    if abs(shift_T_K[nearest] - T_K) < SHIFT_MATCH_TOL:
        aT = shift_aT[nearest]
    else:
        # Keep the previous best-effort fallback (no shift), but say so
        # instead of silently corrupting the master curve.
        warnings.warn(
            f'No shift factor within {SHIFT_MATCH_TOL} K of T = {T_K:.2f} K; '
            'sweep left unshifted (aT = 1)'
        )
        aT = 1.0
    omega_parts.append(np.array(data.x) * aT)
    E_parts.append(np.array(data.y))

omega_master = np.concatenate(omega_parts)
E_master = np.concatenate(E_parts)

# Sort by reduced frequency
sort_idx = np.argsort(omega_master)
omega_master = omega_master[sort_idx]
E_master = E_master[sort_idx]

print(f'Master curve: {len(omega_master)} points')
print(f'Frequency range: {omega_master.min():.2e} - {omega_master.max():.2e} rad/s')
print(f'  = {np.log10(omega_master.max()/omega_master.min()):.1f} decades')

fig, ax = plt.subplots(figsize=(10, 6))
ax.loglog(omega_master, np.real(E_master), 'ro', ms=3, alpha=0.4, label="E'")
ax.loglog(omega_master, np.imag(E_master), 'bs', ms=3, alpha=0.4, label='E"')
ax.set_xlabel(chr(969) + ' aT (rad/s)')
ax.set_ylabel('Modulus (Pa)')
ax.set_title(f'Master Curve at T_ref = {T_ref_C:.0f} C (known shifts)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.close('all')

## 4. Fit Generalized Maxwell to Master Curve

In [None]:
from rheojax.models.multimode.generalized_maxwell import GeneralizedMaxwell

# Use fewer Maxwell modes in FAST_MODE.
n_modes = 5 if FAST_MODE else 10

gmm = GeneralizedMaxwell(n_modes=n_modes, modulus_type='tensile')
gmm.fit(omega_master, E_master, test_mode='oscillation', optimization_factor=None)

# The prediction is indexed by column: 0 is used as E', 1 as E".
E_gmm = gmm.predict(omega_master, test_mode='oscillation')
E_gmm_prime, E_gmm_double = E_gmm[:, 0], E_gmm[:, 1]

# Measured storage / loss moduli taken from the complex master curve.
E_data_prime = np.real(E_master)
E_data_double = np.imag(E_master)

# R-squared of the storage-modulus fit (computed on the linear scale).
residual = E_data_prime - E_gmm_prime
ss_res = np.sum(residual**2)
ss_tot = np.sum((E_data_prime - np.mean(E_data_prime))**2)
R2 = 1 - ss_res / ss_tot

print(f'GMM ({gmm._n_modes} modes): R2(E\') = {R2:.6f}')

# Data as faint markers, model as solid lines.
fig, ax = plt.subplots(figsize=(10, 6))
data_series = (
    (E_data_prime, 'ro', "E' data"),
    (E_data_double, 'bs', 'E" data'),
)
model_series = (
    (E_gmm_prime, 'r-', "E' GMM"),
    (E_gmm_double, 'b-', 'E" GMM'),
)
for values, fmt, label in data_series:
    ax.loglog(omega_master, values, fmt, ms=3, alpha=0.3, label=label)
for values, fmt, label in model_series:
    ax.loglog(omega_master, values, fmt, lw=2, label=label)
ax.set(
    xlabel=chr(969) + ' aT (rad/s)',
    ylabel='Modulus (Pa)',
    title=f'GMM Fit to Master Curve ({gmm._n_modes} modes, R2={R2:.4f})',
)
ax.legend()
plt.tight_layout()
plt.close('all')

## 5. Compare to Pre-Built Master Curve

In [None]:
# Load the pre-built master curve for comparison. File line index 1 is
# skipped on read (presumably a units row — TODO confirm against the CSV).
df_prebuilt = pd.read_csv(os.path.join(data_dir, 'freq_user_master.csv'), skiprows=[1])
# Strip whitespace from header names, as is done for every other CSV in this
# notebook; without it, padded headers make the column lookups below raise
# KeyError.
df_prebuilt.columns = df_prebuilt.columns.str.strip()
omega_pre = 2 * np.pi * df_prebuilt['f'].values  # Hz -> rad/s
E_stor_pre = df_prebuilt['E_stor'].values * 1e6  # MPa -> Pa
E_loss_pre = df_prebuilt['E_loss'].values * 1e6  # MPa -> Pa

# Left panel: storage modulus; right panel: loss modulus. The pre-built curve
# is drawn as a line and the reconstructed points are overlaid as markers.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

ax1.loglog(omega_pre, E_stor_pre, 'k-', lw=2, alpha=0.6, label='Pre-built master')
ax1.loglog(omega_master, np.real(E_master), 'ro', ms=3, alpha=0.3, label='From raw + known shifts')
ax1.set_xlabel(chr(969) + ' (rad/s)')
ax1.set_ylabel("E' (Pa)")
ax1.set_title("E' Comparison: Pre-built vs Reconstructed")
ax1.legend()

ax2.loglog(omega_pre, E_loss_pre, 'k-', lw=2, alpha=0.6, label='Pre-built master')
ax2.loglog(omega_master, np.imag(E_master), 'bs', ms=3, alpha=0.3, label='From raw + known shifts')
ax2.set_xlabel(chr(969) + ' (rad/s)')
ax2.set_ylabel('E" (Pa)')
ax2.set_title('E" Comparison')
ax2.legend()

plt.tight_layout()
plt.close('all')
# NOTE: the agreement is only checked visually; no numerical comparison is made.
print('Pre-built and reconstructed master curves should overlap perfectly.')

## 6. Compare Prony Terms to Reference

In [None]:
# Reference Prony terms. File line index 1 is skipped on read (presumably a
# units row — TODO confirm against the CSV).
prony_csv = os.path.join(data_dir, 'prony_terms_reference.csv')
df_prony = pd.read_csv(prony_csv, skiprows=[1])
df_prony.columns = df_prony.columns.str.strip()
tau_ref = df_prony['tau_i'].values
E_ref = df_prony['E_i'].values  # MPa

# Collect the fitted (tau, modulus) pairs, dropping modes whose modulus is
# not above 1e-6. Parameter names are 1-based: '<prefix>_1', 'tau_1', ...
prefix = 'E' if gmm._modulus_type == 'tensile' else 'G'
kept_modes = []
for mode in range(1, gmm._n_modes + 1):
    E_k = gmm.parameters.get_value(f'{prefix}_{mode}')
    tau_k = gmm.parameters.get_value(f'tau_{mode}')
    if E_k > 1e-6:
        kept_modes.append((tau_k, E_k / 1e6))  # Pa -> MPa

tau_fit = np.array([tau for tau, _ in kept_modes])
E_fit = np.array([modulus for _, modulus in kept_modes])
n_fit = len(tau_fit)

# Bar chart of both discrete spectra on a log10(tau) axis.
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(np.log10(tau_ref), E_ref, width=0.3, alpha=0.5, label=f'Reference ({len(tau_ref)} modes)', color='steelblue')
if n_fit > 0:
    ax.bar(np.log10(tau_fit), E_fit, width=0.2, alpha=0.7, label=f'RheoJAX ({n_fit} modes)', color='coral')
ax.set(
    xlabel='log10(tau / s)',
    ylabel='E_k (MPa)',
    title='Discrete Relaxation Spectrum',
)
ax.legend()
plt.tight_layout()
plt.close('all')

print(f'Reference: {len(tau_ref)} modes, sum(E_k) = {E_ref.sum():.0f} MPa')
if n_fit > 0:
    print(f'Fitted: {n_fit} modes, sum(E_k) = {E_fit.sum():.0f} MPa')

## 7. WLF Parameter Extraction

The WLF equation relates shift factors to temperature:

$$\log_{10} a_T = -\frac{C_1\,(T - T_{\mathrm{ref}})}{C_2 + (T - T_{\mathrm{ref}})}$$

We fit C1 and C2 from the known shift factors.

In [None]:
from scipy.optimize import curve_fit


def wlf(T, C1, C2, T_ref_fit):
    """Evaluate the WLF expression -C1 * (T - T_ref_fit) / (C2 + (T - T_ref_fit)).

    Args:
        T: Temperature(s) at which to evaluate the expression.
        C1: First WLF constant.
        C2: Second WLF constant, in the same temperature units as ``T``.
        T_ref_fit: Reference temperature, in the same units as ``T``.

    Returns:
        The value of the expression; it is zero at ``T == T_ref_fit``.
    """
    offset = T - T_ref_fit
    numerator = -C1 * offset
    denominator = C2 + offset
    return numerator / denominator

# Fit WLF to the shift factors with T_ref held fixed at the reference
# temperature, so only C1 and C2 are free parameters.
T_ref_for_fit = T_ref_C


def wlf_fixed(T, C1, C2):
    """Evaluate ``wlf`` with the reference temperature pinned to ``T_ref_for_fit``."""
    return wlf(T, C1, C2, T_ref_for_fit)


# Only the optimiser call is guarded, and only for the failure types
# curve_fit raises, so a plotting or formatting error further down is not
# misreported as a failed fit.
try:
    popt, pcov = curve_fit(wlf_fixed, T_shift, log_aT, p0=[10.0, 100.0], maxfev=5000)
except (RuntimeError, ValueError, TypeError) as e:
    print(f'WLF fit failed: {e}')
else:
    C1_fit, C2_fit = popt
    # One-sigma parameter uncertainties from the covariance diagonal
    perr = np.sqrt(np.diag(pcov))

    print(f'WLF Parameters (T_ref = {T_ref_for_fit:.0f} C):')
    print(f'  C1 = {C1_fit:.2f} +/- {perr[0]:.2f}')
    print(f'  C2 = {C2_fit:.1f} +/- {perr[1]:.1f} C')

    # Plot fit
    T_dense = np.linspace(T_shift.min(), T_shift.max(), 100)
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(T_shift, log_aT, 'ko', ms=6, label='Data')
    ax.plot(T_dense, wlf_fixed(T_dense, *popt), 'r-', lw=2, label=f'WLF: C1={C1_fit:.1f}, C2={C2_fit:.0f}')
    ax.set_xlabel('Temperature (C)')
    ax.set_ylabel('log(aT)')
    ax.set_title('WLF Fit to Shift Factors')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.close('all')

    # WLF parameters are T_ref-specific. The 'universal' WLF constants
    # (C1=8.86, C2=101.6 K) apply at T_ref = Tg + 50 C. The fit here uses
    # T_ref_for_fit (the reference temperature of this dataset), so the
    # fitted C1/C2 values differ from the universal ones.
    print('\nUniversal WLF values (at Tg+50 C): C1 = 8.86, C2 = 101.6 K')
    print(f'Fitted values: C1 = {C1_fit:.2f}, C2 = {C2_fit:.1f} K')

## Key Takeaways

- **TTS pipeline**: Raw multi-T sweeps + shift factors -> master curve spanning many decades
- **Known shift factors** from pyvisco reconstruct the master curve exactly
- **Prony series** from RheoJAX GMM can be compared to pyvisco reference terms
- **WLF parameters** (C1, C2) encode the temperature dependence of relaxation
- **Master curve = material property** independent of individual temperature sweeps

## Next Steps

- `08_dmta_cross_domain.ipynb`: Validate frequency vs relaxation domain consistency

In [None]:
# Drop the fitted model so its memory can be reclaimed. The NameError guard
# keeps the cell re-runnable: a plain `del gmm` fails when this cell is
# executed twice or when the fit cell never ran.
try:
    del gmm
except NameError:
    pass
gc.collect()
# Clear JAX's caches.
jax.clear_caches()
print('Cleanup complete')