# PI-fMI-Fractional-PDE Analysis

Visualization and analysis of results for a Physics-Informed Neural Network (PINN) that solves fractional PDEs.

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml
from pathlib import Path

from src.model import PINN
from src.dataset import create_test_grid
from src.utils import (
    exact_solution_example,
    compute_errors,
    plot_comparison,
    plot_training_history,
)

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
# Apply matplotlib's bundled seaborn-like white-grid style to all later figures.
# NOTE(review): the 'seaborn-v0_8-*' style names appear to require
# matplotlib >= 3.6 — confirm against the project's pinned version.
plt.style.use('seaborn-v0_8-whitegrid')

## Load Configuration and Model

In [None]:
# Load the experiment configuration.
# The encoding is passed explicitly so that parsing the YAML file does not
# depend on the platform's default locale encoding.
with open('../configs/default.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# Compute device: use the GPU when CUDA is available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

# Problem parameters, all read from the 'problem' section of the config.
problem_cfg = config['problem']
alpha = problem_cfg['alpha']  # NOTE(review): presumably the fractional order — confirm
T = problem_cfg['T']          # upper end of the time range (later cells use t in [0.0, T])
x_min = problem_cfg['x_min']  # lower bound of the spatial interval
x_max = problem_cfg['x_max']  # upper bound of the spatial interval

print(f'Problem: alpha={alpha}, T={T}, x in [{x_min}, {x_max}]')

In [None]:
# Build the network from the 'model' section of the config and move it to
# the selected device.
model_cfg = config['model']
model = PINN(
    **{
        key: model_cfg[key]
        for key in ('input_dim', 'output_dim', 'hidden_layers', 'activation')
    }
).to(device)

# Restore trained weights when a checkpoint is present (update path as needed).
# If none is found the model keeps its freshly initialised weights and only a
# hint is printed.
# NOTE(review): depending on the installed torch version, torch.load may
# unpickle arbitrary objects — only load checkpoint files you trust.
checkpoint_path = '../outputs/checkpoints/best_model.pt'
if not Path(checkpoint_path).exists():
    print("No checkpoint found. Train the model first.")
else:
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint['epoch']}")

# Switch to evaluation mode; kept as the last expression so the notebook
# still displays the returned module.
model.eval()

## Evaluate on Test Grid

In [None]:
# Build an N_test x N_test evaluation grid over [0, T] x [x_min, x_max].
N_test = 100
grid_spec = dict(
    t_range=(0.0, T),
    x_range=(x_min, x_max),
    N_t=N_test,
    N_x=N_test,
    device=device,
)
t_grid, x_grid, T_mesh, X_mesh = create_test_grid(**grid_spec)

# Evaluate the network without tracking gradients and fold the output into
# the (N_test, N_test) grid layout.
with torch.no_grad():
    u_pred = model(t_grid, x_grid).reshape(N_test, N_test)

# Reference solution evaluated on the mesh arrays returned by create_test_grid.
u_exact = exact_solution_example(T_mesh, X_mesh, alpha)

print(f'Prediction shape: {u_pred.shape}')

## Error Analysis

In [None]:
# Compare the prediction against the exact solution and report every metric
# returned by compute_errors in scientific notation.
errors = compute_errors(u_pred, u_exact)

print('Error Metrics:')
for metric, score in errors.items():
    line = f'  {metric}: {score:.6e}'
    print(line)

## Visualization

In [None]:
# Plot comparison
# Pass the mesh coordinates together with the predicted and exact fields to
# the project plotting helper imported from src.utils.
# NOTE(review): presumably draws prediction, exact solution and their
# difference side by side — confirm against plot_comparison's implementation.
plot_comparison(T_mesh, X_mesh, u_pred, u_exact)

In [None]:
# Plot the training history saved by the training run, if one exists.
import json

history_path = '../outputs/logs/training_history.json'

# Open the file directly instead of checking for existence first: this avoids
# the check-then-open race, and the explicit encoding keeps JSON decoding
# independent of the platform's default locale.
try:
    with open(history_path, 'r', encoding='utf-8') as f:
        history = json.load(f)
except FileNotFoundError:
    print('No training history found.')
else:
    # Only reached when the history was read and parsed successfully.
    plot_training_history(history)

## Solution at Different Time Snapshots

In [None]:
# Plot the predicted spatial profile u(t, x) at several snapshot times.
# The snapshots are fractions of the configured horizon T, so they always lie
# inside [0, T] like the test grid; for T = 1 this reproduces the previous
# fixed list [0.0, 0.25, 0.5, 0.75, 1.0].
snapshot_fractions = (0.0, 0.25, 0.5, 0.75, 1.0)
times = [frac * T for frac in snapshot_fractions]
x_plot = torch.linspace(x_min, x_max, 100, device=device)

fig, ax = plt.subplots(figsize=(10, 6))

# Loop-invariant pieces: the column-shaped model input for x and the host-side
# copy of the x coordinates used for plotting.
x_input = x_plot.unsqueeze(-1)
x_values = x_plot.cpu().numpy()

for t_val in times:
    # Constant-time tensor with the same shape, dtype and device as x_plot.
    t_tensor = torch.full_like(x_plot, t_val)
    with torch.no_grad():
        u_t = model(t_tensor.unsqueeze(-1), x_input)
    # Convert explicitly to NumPy so plotting does not rely on matplotlib's
    # implicit tensor conversion; rounding keeps float noise out of the label.
    ax.plot(x_values, u_t.cpu().numpy(), label=f't={round(t_val, 4)}')

ax.set_xlabel('x')
ax.set_ylabel('u(t, x)')
ax.set_title('PINN Solution at Different Times')
ax.legend()
plt.show()