# Neural WFA Inversion (Refactored)

This notebook demonstrates how to use the refactored `neural_wfa` package to invert solar spectropolarimetric data with neural fields under the weak-field approximation (WFA).

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import astropy.io.fits as fits
import os, sys

# Append both candidate source roots ("src" and "../src", both relative to the
# current working directory) so the in-repo `neural_wfa` package can be
# imported when running from the repository root or from one level below it.
sys.path.extend(["src", "../src"])

from neural_wfa.core.observation import Observation
from neural_wfa.core.problem import WFAProblem
from neural_wfa.physics.line_info import LineInfo
from neural_wfa.nn.mlp import MLP
from neural_wfa.optimization.solver import NeuralSolver
from neural_wfa.core.magnetic_field import MagneticField
from neural_wfa.analysis.uncertainty import estimate_uncertainties_diagonal

## 1. Load Data

In [None]:
# Locate the example data. The path is relative to the current working
# directory, so try the repository root first and then one level up.
datadir = "example/plage_sst/"
if not os.path.exists(datadir):
    datadir = "../example/plage_sst/"


def _read_fits_float32(path):
    """Return the primary-HDU data of the FITS file at *path* as float32.

    The file is opened in a ``with`` block so its handle is always closed.
    ``np.array`` makes an independent, C-contiguous copy, so the returned
    array never references a memory map that was closed on exit.
    """
    with fits.open(path, "readonly") as hdul:
        return np.array(hdul[0].data, dtype="float32", order="C")


img = _read_fits_float32(os.path.join(datadir, "CRISP_5173_plage_dat.fits"))
xl = _read_fits_float32(os.path.join(datadir, "CRISP_5173_plage_wav.fits"))

print("Data shape:", img.shape)
# The unpack requires a 4-D cube; the names suggest (y, x, Stokes, wavelength)
# -- TODO confirm against the CRISP data description.
ny, nx, ns, nw = img.shape

## 2. Setup Problem

In [None]:
# Compute device: CUDA when a GPU is visible to torch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Observation: wraps the Stokes cube `img` and the wavelength grid `xl`.
# `mask=[5, 6, 7]` selects three entries -- presumably wavelength indices
# used by the fit (TODO confirm in Observation). Note that this constructor
# receives the device as a string, while WFAProblem below gets the
# torch.device object itself.
obs = Observation(img, xl, mask=[5, 6, 7], device=str(device))

# Line parameters, keyed by 5173 (the same label as the CRISP_5173 data files).
lin = LineInfo(5173)

# WFA physics engine combining the observation and the line data.
problem = WFAProblem(obs, lin, device=device)

## 3. Initialize Neural Fields

In [None]:
# Coordinate Grid (normalized -1 to 1)
# Pixel coordinates normalised to [-1, 1] on each axis and flattened in
# row-major order, so row k of `coords` is pixel (k // nx, k % nx).
grid_y, grid_x = np.meshgrid(
    np.linspace(-1, 1, ny), np.linspace(-1, 1, nx), indexing="ij"
)
coords = torch.from_numpy(
    np.stack([grid_y, grid_x], axis=-1).reshape(-1, 2).astype(np.float32)
).to(device)

# Hyper-parameters shared by both networks. Only the number of outputs and
# `sigma` (presumably the Fourier-feature frequency scale) differ.
_shared_mlp_kwargs = dict(
    dim_in=2,
    dim_hidden=64,
    num_resnet_blocks=2,
    fourier_features=True,
    m_freqs=512,
    tune_beta=False,
)

# Blos (line-of-sight magnetic field): one output channel, sigma = 40.
model_blos = MLP(dim_out=1, sigma=40.0, **_shared_mlp_kwargs)

# BQU (transverse magnetic field components): two output channels, sigma = 8.
model_bqu = MLP(dim_out=2, sigma=8.0, **_shared_mlp_kwargs)

## 4. Train using Neural Solver

In [None]:
solver = NeuralSolver(
    problem=problem,
    model_blos=model_blos,
    model_bqu=model_bqu,
    coordinates=coords,
    lr=5e-4,
    batch_size=200_000,
    device=device,
)

# Two sequential 200-epoch phases. The first optimises only the Blos model;
# the second optimises only the BQU model (per the optimize_* flags).
_training_phases = (
    ("Training Phase 1: Blos Only...", True, False),
    ("Training Phase 2: BQU Only...", False, True),
)
for phase_banner, fit_blos, fit_bqu in _training_phases:
    print(phase_banner)
    solver.train(n_epochs=200, optimize_blos=fit_blos, optimize_bqu=fit_bqu)

## 5. Visualize Results & Analysis

In [None]:
final_field = solver.get_full_field()

# 1. Magnetic Field Maps
# Move each component to host memory and fold the flat pixel axis into (ny, nx).
blos_map, bq_map, bu_map = (
    component.detach().cpu().numpy().reshape(ny, nx)
    for component in (final_field.blos, final_field.b_q, final_field.b_u)
)
btrans_map = np.sqrt(bq_map**2 + bu_map**2)

# (image, colormap, title) for each panel, left to right.
# NOTE(review): arctan2(BU, BQ) is the angle of the (BQ, BU) vector. Whether
# that equals the field azimuth or twice it depends on how the package
# defines b_q/b_u -- confirm before interpreting the "Azimuth" panel.
_map_panels = (
    (blos_map, 'gray', "Blos (LOS Field)"),
    (btrans_map, 'viridis', "Btrans"),
    (np.arctan2(bu_map, bq_map), 'hsv', "Azimuth"),
)
plt.figure(figsize=(15, 5))
for panel_no, (panel_img, panel_cmap, panel_title) in enumerate(_map_panels, start=1):
    plt.subplot(1, 3, panel_no)
    plt.imshow(panel_img, cmap=panel_cmap, origin='lower')
    plt.title(panel_title)
    plt.colorbar()
plt.tight_layout()
plt.show()

# 2. Profile Fitting Check
# Compare observed and modelled Stokes profiles at one example pixel.
# The requested pixel is (100, 100). If it lies outside the (ny, nx) map,
# fall back to the pixel with the strongest |Blos|. Without this check the
# flat index below could silently wrap to a different pixel (px >= nx) or
# raise IndexError.
py, px = 100, 100
if not (0 <= py < ny and 0 <= px < nx):
    py, px = (
        int(i)
        for i in np.unravel_index(np.argmax(np.abs(blos_map)), blos_map.shape)
    )

# Flat index into the row-major (ny, nx) pixel ordering used to build `coords`.
idx = py * nx + px
indices = torch.tensor([idx], device=device)

# Compute Model Profiles for just this pixel.
field_sub = MagneticField(
    final_field.blos[indices],
    torch.stack([final_field.b_q[indices], final_field.b_u[indices]], dim=-1),
)

stokesQ, stokesU, stokesV = problem.compute_forward_model(field_sub, indices=indices)


def _to_flat_numpy(t):
    """Detach tensor *t*, move it to host memory and return it as a 1-D array."""
    return t.detach().cpu().numpy().flatten()


obs_Q = _to_flat_numpy(obs.stokes_Q[indices])
obs_U = _to_flat_numpy(obs.stokes_U[indices])
obs_V = _to_flat_numpy(obs.stokes_V[indices])
mod_Q = _to_flat_numpy(stokesQ)
mod_U = _to_flat_numpy(stokesU)
mod_V = _to_flat_numpy(stokesV)
wav = obs.wavelengths.detach().cpu().numpy()

plt.figure(figsize=(12, 4))
plt.subplot(131)
plt.plot(wav, obs_Q, 'ok', label='Obs')
plt.plot(wav, mod_Q, '-r', label='WFA')
plt.title("Stokes Q")
plt.legend()
plt.subplot(132)
plt.plot(wav, obs_U, 'ok')
plt.plot(wav, mod_U, '-r')
plt.title("Stokes U")
plt.subplot(133)
plt.plot(wav, obs_V, 'ok')
plt.plot(wav, mod_V, '-r')
plt.title("Stokes V")
plt.tight_layout()
plt.show()

# 3. Loss of the full fitted field. `compute_loss` yields a single-element
# tensor here (hence `.item()`), i.e. one scalar, not a per-pixel chi2 map.
loss_val = problem.compute_loss(final_field).item()
print(f"Total Loss: {loss_val:.2f}")

# 4. Uncertainty Estimation (Analytical)
# Neural fields give smooth solutions, but per-pixel uncertainties can still
# be estimated by analytical error propagation.
print("Estimating Uncertainties...")
unc = estimate_uncertainties_diagonal(problem, final_field)


def _as_map(values):
    """Return *values* (a torch tensor or array-like) as an (ny, nx) NumPy array.

    The container type returned by `estimate_uncertainties_diagonal` is not
    visible here. A tensor on the GPU, or one that requires grad, cannot be
    handed to `plt.imshow` directly, so convert explicitly.
    """
    if isinstance(values, torch.Tensor):
        values = values.detach().cpu().numpy()
    return np.asarray(values).reshape(ny, nx)


sigma_blos = _as_map(unc['sigma_blos'])
sigma_btrans = _as_map(unc['sigma_btrans'])

# Fixed colour limits; values above vmax are shown saturated.
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.imshow(sigma_blos, cmap='inferno', origin='lower', vmin=0, vmax=50)
plt.title("Sigma Blos [G]")
plt.colorbar()
plt.subplot(122)
plt.imshow(sigma_btrans, cmap='inferno', origin='lower', vmin=0, vmax=200)
plt.title("Sigma Btrans [G]")
plt.colorbar()
plt.tight_layout()
plt.show()