# Spectral PINN for Heat Equation Demo

This notebook demonstrates the Spectral Physics-Informed Neural Network (SPINN) approach for solving the 1D heat equation.
The model learns to map a sampled initial condition $u(x,0)$ to its spectral coefficients $c_k$, so that the time evolution can be applied analytically to each mode.
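
For reference, one common form of such a spectral solution is written out here. It is a sketch that assumes the heat equation $u_t = \nu u_{xx}$ on $[0, L]$ with homogeneous Dirichlet boundary conditions and a sine basis; the basis and boundary conditions actually used by `SpectralPINN` are defined in `src/model.py` and may differ.

$$
u(x,t) = \sum_{k=1}^{K} c_k \, e^{-\lambda_k t} \sin\!\left(\frac{k \pi x}{L}\right), \qquad \lambda_k = \nu \left(\frac{k \pi}{L}\right)^2
$$

Under these assumptions $K$ corresponds to `NUM_MODES` and $\nu$ to `NU`. Setting $t = 0$ removes the exponential factor and leaves the reconstruction of the initial condition, $u(x,0) = \sum_{k} c_k \sin(k \pi x / L)$.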

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Ensure src is on the import path (compare the full path, not the bare name 'src')
src_path = os.path.join(os.getcwd(), 'src')
if src_path not in sys.path:
    sys.path.append(src_path)

from model import SpectralPINN
from data import generate_initial_condition

## Configuration and Model Initialization

In [None]:
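# Parameters
L = 1.0            # Domain length, x in [0, L]
NU = 0.01          # Diffusivity
NUM_MODES = 20     # Number of spectral modes (coefficients c_k)
HIDDEN_DIM = 64    # Hidden width of the network
U0_SAMPLES = 50    # Number of sample points of u(x, 0) fed to the network
BATCH_SIZE = 32    # Initial conditions per training step
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")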

# Initialize Model
model = SpectralPINN(num_modes=NUM_MODES, 
                     hidden_dim=HIDDEN_DIM, 
                     input_dim=U0_SAMPLES, 
                     L=L, 
                     nu=NU).to(device)
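
`SpectralPINN` is imported from `src/model.py`, which is not shown in this notebook. To make the call signature above concrete, here is a minimal sketch of what such a module could look like. Everything in it is an assumption for illustration: the class name `SpectralPINNSketch`, the tanh MLP encoder, the Dirichlet sine basis, and the buffer names `wavenumber` and `decay` are hypothetical and may not match the real implementation.

```python
import math

import torch
import torch.nn as nn


class SpectralPINNSketch(nn.Module):
    """Hypothetical stand-in for SpectralPINN (not the code in src/model.py)."""

    def __init__(self, num_modes, hidden_dim, input_dim, L=1.0, nu=0.01):
        super().__init__()
        # Encoder (assumed architecture): sampled u(x, 0) -> coefficients c_k
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, num_modes),
        )
        k = torch.arange(1, num_modes + 1, dtype=torch.float32)
        # Assumed Dirichlet sine basis: wavenumbers k*pi/L, decay rates nu*(k*pi/L)^2
        self.register_buffer("wavenumber", k * math.pi / L)
        self.register_buffer("decay", nu * (k * math.pi / L) ** 2)

    def forward(self, u0, x, t):
        # u0: (B, input_dim); x and t: (B, N, 1)
        c = self.encoder(u0).unsqueeze(1)        # (B, 1, K)
        basis = torch.sin(x * self.wavenumber)   # (B, N, K)
        damping = torch.exp(-t * self.decay)     # (B, N, K), equals 1 at t = 0
        # Sum over modes; output has shape (B, N, 1)
        return (c * basis * damping).sum(dim=-1, keepdim=True)
```

In this sketch only the encoder has trainable parameters. The basis and decay rates are fixed buffers, so the time dependence is never learned, and the output shape `(B, N, 1)` is consistent with the `squeeze(-1)` applied in the training loss.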

## Training Loop

We train the model to reconstruct the initial condition $u(x,0)$ using the spectral basis.

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)
loss_history = []

print("Starting training...")
for epoch in range(1001):
    optimizer.zero_grad()
    
    # 1. Generate Data (Random initial conditions)
    # x_fn are the sample locations, u0_vals are the function values
    x_fn, u0_vals = generate_initial_condition(BATCH_SIZE, U0_SAMPLES, L=L, type='mix_sine')
    u0_vals = u0_vals.to(device)
    
    # Input to network is the discretized function
    u0_input = u0_vals
    
    # Query points for loss evaluation (same as input points for reconstruction)
    # x_query needs to be (Batch, N, 1)
    x_query = x_fn.unsqueeze(-1).to(device)
    t_zeros = torch.zeros(BATCH_SIZE, U0_SAMPLES, 1, device=device)
    
    # Forward pass at t=0
    u_rec_0 = model(u0_input, x_query, t_zeros)
    
    # Loss: Reconstruction error of initial condition
    loss = nn.MSELoss()(u_rec_0.squeeze(-1), u0_vals)
    
    loss.backward()
    optimizer.step()
    loss_history.append(loss.item())
    
    if epoch % 200 == 0:
        print(f"Epoch {epoch}: Loss {loss.item():.6f}")

plt.plot(loss_history)
plt.title("Training Loss (Initial Condition Reconstruction)")
plt.yscale('log')
plt.show()

## Visualization of Solution

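We can now visualize the solution $u(x,t)$ for a test case. Because the spectral layer applies the time evolution analytically (each mode is scaled by $e^{-\lambda_k t}$), an accurate reconstruction at $t=0$ carries over to $t>0$, provided the initial condition is well captured by the first `NUM_MODES` modes and satisfies the boundary conditions built into the basis.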
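
The effect of the analytic factor can be checked independently of the network. The sketch below is a NumPy-only illustration that assumes a Dirichlet sine basis with $\lambda_k = \nu (k\pi/L)^2$; the notebook does not state which basis `SpectralPINN` uses, so treat that as an assumption. The initial condition, the grid size, and the helper name `evolve` are chosen here for illustration only. The coefficients are obtained by numerical projection instead of a trained network.

```python
import numpy as np

# Same values as the notebook's L, NU and NUM_MODES
L, nu, num_modes = 1.0, 0.01, 20
x = np.linspace(0.0, L, 201)
k = np.arange(1, num_modes + 1)

# Illustrative initial condition built from two sine modes
u0 = np.sin(np.pi * x / L) + 0.5 * np.sin(3 * np.pi * x / L)

# Assumed sine basis, one row per mode: shape (K, Nx)
basis = np.sin(np.outer(k, x) * np.pi / L)

# Projection c_k = (2/L) * integral of u0(x) * sin(k*pi*x/L) dx.
# The integrand vanishes at both endpoints, so the trapezoidal rule
# reduces to a plain sum times the grid spacing.
dx = x[1] - x[0]
c = (2.0 / L) * (basis * u0).sum(axis=1) * dx

# Decay rate of each mode
lam = nu * (k * np.pi / L) ** 2


def evolve(t):
    """Evaluate u(x, t) by damping each coefficient analytically."""
    return (c * np.exp(-lam * t)) @ basis


u_start = evolve(0.0)   # reproduces u0 on the grid
u_later = evolve(0.5)   # solution at the end of the plotted time window
```

With $\nu = 0.01$ the lowest mode loses only about 5% of its amplitude by $t = 0.5$, while higher modes decay much faster, so a heatmap over this time window is expected to change slowly along the $t$ axis.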

In [None]:
# Test case
test_x, test_u0 = generate_initial_condition(1, U0_SAMPLES, L=L, type='mix_sine')
test_u0 = test_u0.to(device)

# Create a grid for visualization
t_eval = torch.linspace(0, 0.5, 50).to(device)  # Time from 0 to 0.5 (matches the imshow extent)
x_eval = torch.linspace(0, L, 50).to(device)

# Create mesh
T_grid, X_grid = torch.meshgrid(t_eval, x_eval, indexing='ij')

# Prepare inputs for model
# A single initial condition (batch size 1) is shared by all query points
u0_input_test = test_u0  # (1, U0_SAMPLES)

# Flatten grids for batch processing
t_flat = T_grid.reshape(1, -1, 1)
x_flat = X_grid.reshape(1, -1, 1)

with torch.no_grad():
    u_pred = model(u0_input_test, x_flat, t_flat)
    
u_pred_grid = u_pred.reshape(T_grid.shape).cpu().numpy()

# Plot
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.imshow(u_pred_grid, aspect='auto', origin='lower', extent=[0, L, 0, 0.5])
plt.colorbar(label='u(x,t)')
plt.xlabel('x')
plt.ylabel('t')
plt.title('Predicted Solution (Spectral PINN)')

plt.subplot(1, 2, 2)
# Compare t=0 reconstruction
u_pred_0 = u_pred_grid[0, :]
plt.plot(x_eval.cpu().numpy(), u_pred_0, 'b--', label='Reconstruction')
plt.plot(test_x.flatten().cpu().numpy(), test_u0.cpu().flatten().numpy(), 'r.', label='Ground Truth Samples')
plt.title('t=0 Reconstruction')
plt.legend()

plt.tight_layout()
plt.show()