# 02 — PINN Sod Experiment

Visualização das curvas de treinamento da PINN e comparação com a solução analítica do problema de Sod.

In [None]:
import sys
from pathlib import Path

PROJECT_ROOT = Path('..').resolve()
SRC_PATH = PROJECT_ROOT.joinpath('src')

# Make the in-repo package importable without installing it. Prepending lets the
# local checkout win over any installed copy; the guard keeps re-runs of this
# cell from stacking duplicate entries on sys.path.
if not any(entry == str(SRC_PATH) for entry in sys.path):
    sys.path.insert(0, str(SRC_PATH))

In [None]:
import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from omegaconf import OmegaConf

from riemann_ml.core.euler1d import StatePrim
from riemann_ml.exact.sod_exact import sod_exact_profile
from riemann_ml.ml.pinn.model import PINN, conservative_to_primitive

# Ratio of specific heats. Used only as a fallback when the config has no
# `gamma`; the config value wins so the model and the exact solution agree.
GAMMA = 1.4
# Anchor artifact paths on PROJECT_ROOT (set in the first cell) instead of the
# current working directory: the first cell assumes the notebook runs one level
# below the project root, so bare 'src/...' / 'data/...' paths would not resolve.
CONFIG_PATH = PROJECT_ROOT / 'src' / 'riemann_ml' / 'configs' / 'pinn.yaml'
CHECKPOINT_DIR = PROJECT_ROOT / 'data' / 'artifacts' / 'pinn' / 'checkpoints'
OUTPUT_DIR = PROJECT_ROOT / 'data' / 'artifacts' / 'pinn' / 'outputs'

# ---------------------------------------------------------------------------
# 1) Training curves from the loss history written during training.
# ---------------------------------------------------------------------------
print('Carregando histórico de perdas...')
history_path = OUTPUT_DIR / 'loss_history.json'
if history_path.exists():
    # Expected: a list of per-step records carrying the keys read below
    # (`step`, `loss_total`, `loss_pde`, `loss_ic`, `loss_bc`).
    history = json.loads(history_path.read_text(encoding='utf-8'))
    if history:
        steps = [entry['step'] for entry in history]
        total = [entry['loss_total'] for entry in history]
        pde = [entry['loss_pde'] for entry in history]
        ic = [entry['loss_ic'] for entry in history]
        bc = [entry['loss_bc'] for entry in history]

        plt.figure(figsize=(8, 5))
        plt.plot(steps, total, label='Total')
        for values, label in ((pde, 'PDE'), (ic, 'IC'), (bc, 'BC')):
            plt.plot(steps, values, label=label, alpha=0.7)
        plt.xlabel('Passo')
        plt.ylabel('Loss')
        plt.yscale('log')
        plt.legend()
        plt.grid(True)
        plt.title('Curvas de treinamento da PINN')
        plt.show()
    else:
        # An empty history would otherwise yield a blank log-scale plot.
        print('Histórico vazio:', history_path)
else:
    print('Histórico não encontrado:', history_path)

# ---------------------------------------------------------------------------
# 2) PINN prediction vs. exact Sod solution at the final time.
# ---------------------------------------------------------------------------
print('Gerando comparação final...')
cfg = OmegaConf.load(CONFIG_PATH)
# One gamma for both the conservative->primitive conversion and the exact
# solution, so the comparison is self-consistent.
gamma = float(cfg.get('gamma', GAMMA))
t_final = float(cfg.domain.t_max)

model = PINN(OmegaConf.to_container(cfg.model, resolve=True))
# Call the model once on a dummy (x, t) batch so it is built (its variables
# exist) before the checkpoint is restored into it.
model(tf.zeros((1, 2), dtype=tf.float32))
latest = tf.train.latest_checkpoint(str(CHECKPOINT_DIR))
if latest:
    # expect_partial() silences warnings about checkpoint values this model does
    # not consume (e.g. optimizer state saved during training).
    tf.train.Checkpoint(model=model).restore(latest).expect_partial()

    # Evaluate the PINN on a uniform spatial grid at t = t_final.
    x = np.linspace(cfg.domain.x_min, cfg.domain.x_max, 256, dtype=np.float32)
    t = np.full_like(x, t_final, dtype=np.float32)
    preds = model.predict_conservative(
        tf.convert_to_tensor(x[:, None]),
        tf.convert_to_tensor(t[:, None]),
        training=False,
    )
    # The three output columns are passed to conservative_to_primitive in order;
    # presumably (rho, rho*u, E) — confirm against riemann_ml.ml.pinn.model.
    rho, u, p = conservative_to_primitive(
        preds[:, 0:1], preds[:, 1:2], preds[:, 2:3], gamma
    )
    rho = rho.numpy().ravel()
    u = u.numpy().ravel()
    p = p.numpy().ravel()

    left_state = StatePrim(**OmegaConf.to_container(cfg.sod.left, resolve=True))
    right_state = StatePrim(**OmegaConf.to_container(cfg.sod.right, resolve=True))
    # x is shifted by the interface position, so sod_exact_profile presumably
    # expects coordinates relative to the initial discontinuity.
    rho_exact, u_exact, p_exact = sod_exact_profile(
        x - cfg.sod.interface,
        t_final,
        left_state=left_state,
        right_state=right_state,
        gamma=gamma,
    )

    panels = (
        ('Density', rho_exact, rho),
        ('Velocity', u_exact, u),
        ('Pressure', p_exact, p),
    )
    fig, axes = plt.subplots(3, 1, figsize=(10, 9), sharex=True)
    for ax, (ylabel, exact_values, pinn_values) in zip(axes, panels):
        ax.plot(x, exact_values, label='Exact')
        ax.plot(x, pinn_values, '--', label='PINN')
        ax.set_ylabel(ylabel)
        ax.grid(True)
        ax.legend()  # every panel gets a legend, not only the first
    axes[-1].set_xlabel('x')

    fig.suptitle('PINN vs Exact (t = {:.3f})'.format(t_final))
    fig.tight_layout()
    plt.show()
else:
    print('Checkpoint não encontrado em', CHECKPOINT_DIR)


## Métricas consolidadas

Carrega as métricas calculadas pelo script de avaliação para interpretar o desempenho da PINN.


In [None]:
import json
from pathlib import Path

# Resolve against PROJECT_ROOT (set in the first cell) so the lookup does not
# depend on the notebook's working directory.
metrics_path = PROJECT_ROOT / 'data' / 'artifacts' / 'eval' / 'sod' / 'sod_metrics.json'
if metrics_path.exists():
    # Expected: a JSON object keyed by method name; only the 'PINN' entry is shown.
    metrics = json.loads(metrics_path.read_text(encoding='utf-8'))
    pinn_metrics = metrics.get('PINN', {})
    if pinn_metrics:
        print('Métricas PINN vs. exata:')
        for key, value in pinn_metrics.items():
            print(f'  {key}: {value}')
    else:
        # Make a missing/empty 'PINN' section explicit instead of printing an
        # empty header.
        print("Nenhuma métrica 'PINN' encontrada em", metrics_path)
else:
    print('Arquivo de métricas não encontrado:', metrics_path)


> **Comentário:** utilize essas métricas para discutir onde a PINN precisa de ajustes (por exemplo, erros relativos elevados em relação à solução exata indicam a necessidade de recalibrar o treinamento).
