# Chapter IV: Improving Training of PINNs

This notebook covers techniques to improve the training of Physics-Informed Neural Networks (PINNs).

In [None]:
# Setup cell: extends the import path, sets the run/load switch and imports
# everything used by the cells below.
import sys
import os
# Make the Allen-Cahn example package importable. This must run before the
# `from allen_cahn_train import ...` line further down. The path is relative,
# so it is resolved against the kernel's current working directory.
sys.path.append('examples/allen_cahn')

# Option to run code or load saved results.
# True: each section trains its models and pickles the results under results/.
# False: each section reads the previously pickled results instead of training.
RUN_CODE = True  # Set to False to load saved results

# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
# NOTE(review): the original comment called this import "optional", but the
# 'science' and 'grid' styles selected on the next line presumably come from
# scienceplots, so the import appears to be required here. Requires LaTeX.
import scienceplots  # optional, requires latex
plt.style.use(['science', 'grid'])
import pickle
import pandas as pd
# Project-local training entry point, found via the sys.path entry added above.
from allen_cahn_train import train_allen_cahn
from scipy.io import loadmat

## IV_1: Classical Neural Network Techniques

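This section discusses learning rate schedules, regularization, normalization, and activation functions. The example below covers the last of these: it trains the same PINN with `tanh`, `sin`, and `relu` activations and compares their PDE-loss curves.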

In [None]:
# IV_1 example: train the same PINN with different activation functions and
# compare the resulting PDE-loss histories.

# Activation names in the order the runs are trained and stored. Defined
# outside the RUN_CODE branch so the legend can name each curve even when the
# results are loaded from disk.
activations_iv1 = ["tanh", "sin", "relu"]

if RUN_CODE:
    config_base = {
        "fourier_features": False,
        "net_type": "PINN",
        "n_domain": 100**2,
        "n_iters": 5000,
        "seed": 0,
    }
    # One config per activation; everything else is shared.
    configs = [{**config_base, "activations": name} for name in activations_iv1]
    results_iv1 = [train_allen_cahn(c) for c in configs]
    os.makedirs('results/IV_ImprovingTraining/iv1_classical', exist_ok=True)
    with open('results/IV_ImprovingTraining/iv1_classical/results.pkl', 'wb') as f:
        # The "model" entry is dropped; only the remaining fields are saved.
        pickle.dump([{k: v for k, v in r.items() if k != "model"} for r in results_iv1], f)
else:
    # NOTE: pickle.load can execute arbitrary code. Only load result files
    # that were produced by this notebook.
    with open('results/IV_ImprovingTraining/iv1_classical/results.pkl', 'rb') as f:
        results_iv1 = pickle.load(f)

# Plot losses
plt.figure()
for i, res in enumerate(results_iv1):
    # Label each curve with its activation name. Fall back to the index if a
    # saved results file holds more runs than the list above.
    label = activations_iv1[i] if i < len(activations_iv1) else f'Activation {i}'
    plt.plot(res['pde_loss'], label=label)
plt.legend()
plt.title('PDE Loss for Different Activations')
plt.yscale('log')
plt.savefig('results/IV_ImprovingTraining/iv1_classical/loss_plot.png')
plt.show()

## IV_2: Fourier Feature Embedding

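Fourier feature embeddings help the network learn high-frequency components of the solution. The example below trains a SPINN with and without Fourier features, then compares both predictions with the reference solution and plots their PDE-loss curves.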

In [None]:
if not RUN_CODE:
    # Reuse the fields stored by an earlier run instead of retraining.
    with open('results/IV_ImprovingTraining/iv2_fourier/fields.pkl', 'rb') as f:
        saved = pickle.load(f)
    data_spinn_fourier = saved['fourier']
    data_spinn_no_fourier = saved['no_fourier']
else:
    # Shared SPINN settings; the two runs differ only in "fourier_features".
    config_spinn = {
        "fourier_features": True,
        "n_fourier_features": 128,
        "sigma": 10,
        "net_type": "SPINN",
        "mlp_type": "mlp",
        "activations": "sin",
        "n_domain": 150**2,
        "n_iters": 10000,  # Reduced
        "rank": 64,
        "seed": 0,
    }
    config_spinn_no_fourier = {**config_spinn, "fourier_features": False}

    # Train with Fourier features first, then without (same order as before).
    data_spinn_fourier = train_allen_cahn(config_spinn)
    data_spinn_no_fourier = train_allen_cahn(config_spinn_no_fourier)

    # Keep every field except "model", which is excluded to avoid pickling issues.
    fields_to_save = {
        tag: {key: value for key, value in run.items() if key != "model"}
        for tag, run in (
            ('fourier', data_spinn_fourier),
            ('no_fourier', data_spinn_no_fourier),
        )
    }
    os.makedirs('results/IV_ImprovingTraining/iv2_fourier', exist_ok=True)
    with open('results/IV_ImprovingTraining/iv2_fourier/fields.pkl', 'wb') as f:
        pickle.dump(fields_to_save, f)

# Load test data
def gen_testdata(path="examples/dataset/Allen_Cahn.mat"):
    """Load the Allen-Cahn reference solution and build its coordinate grids.

    Args:
        path: Location of the MATLAB ``.mat`` file to read. The file must
            contain the variables ``t``, ``x`` and ``u``. The default is the
            dataset path relative to the current working directory.

    Returns:
        A tuple ``(y, xx, tt, u)`` where

        * ``y`` is ``u`` flattened into a column of shape ``(u.size, 1)``,
        * ``xx`` and ``tt`` are the grids from
          ``np.meshgrid(x, t, indexing="ij")``, so axis 0 follows ``x`` and
          axis 1 follows ``t``,
        * ``u`` is the solution array exactly as stored in the file.
    """
    data = loadmat(path)
    t = data["t"]
    x = data["x"]
    u = data["u"]

    # "ij" indexing gives grids of shape (len(x), len(t)).
    # NOTE(review): this assumes u is stored with the same (x, t) layout;
    # confirm against the dataset.
    xx, tt = np.meshgrid(x, t, indexing="ij")
    y = u.flatten()[:, None]

    return y, xx, tt, u

# Reference solution and the two SPINN predictions to compare.
y, xx, tt, u_true = gen_testdata()
u_pred_fourier = data_spinn_fourier["u_pred"]
u_pred_no_fourier = data_spinn_no_fourier["u_pred"]

# Plot
fig, ax = plt.subplots(1, 3, figsize=(12, 4))

# Shared colour limits so the three panels are directly comparable. The
# extrema are taken per array and then combined, so the fields are never
# stacked into one temporary array.
fields = (u_true, u_pred_fourier, u_pred_no_fourier)
vmin = np.nanmin([np.nanmin(field) for field in fields])
vmax = np.nanmax([np.nanmax(field) for field in fields])

# One panel per field, left to right.
panels = (
    ("Ground Truth", u_true),
    ("Prediction (No Fourier)", u_pred_no_fourier),
    ("Prediction (With Fourier)", u_pred_fourier),
)
for axis, (title, field) in zip(ax, panels):
    axis.pcolor(tt, xx, field, vmin=vmin, vmax=vmax)
    axis.set_title(title)
    # Every panel shares the same horizontal axis, so label it on all of
    # them rather than only on the last one.
    axis.set_xlabel("t")
    axis.set_xticks([])
    axis.set_yticks([])
    axis.set_aspect(1/2)
# The panels sit side by side, so the vertical label is only needed once.
ax[0].set_ylabel("x")

plt.tight_layout()
plt.savefig('results/IV_ImprovingTraining/iv2_fourier/prediction_plot.png')
plt.show()

# Compare the PDE-loss histories of the two SPINN runs on a log-scaled y-axis.
loss_fig, loss_ax = plt.subplots()
for legend_label, run in (
    ('No Fourier', data_spinn_no_fourier),
    ('With Fourier', data_spinn_fourier),
):
    loss_ax.plot(run["pde_loss"], label=legend_label)
loss_ax.legend()
loss_ax.set_title('PDE Loss')
loss_ax.set_yscale('log')
loss_fig.savefig('results/IV_ImprovingTraining/iv2_fourier/loss_plot.png')
plt.show()

## IV_3: Adaptive Sampling and Attention

Adaptive sampling resamples points based on residuals. Attention focuses on important features.