# WFA constraint

## 1. Prerequisites

In [1]:
import numpy as np
import torch
import torch.nn as nn

from scipy.interpolate import interp1d
from torchinfo import summary
from tqdm import tqdm
import matplotlib.pyplot as plt

import astropy.units as u
from astropy.constants import c, e, m_e

import sys
sys.path.append('..')
from modules_2.charge_data import DataCharger

# Seed every RNG the notebook may touch so that Restart & Run All reproduces the outputs.
# NumPy is seeded as well because the data pipeline below reports an "Adding noise" step;
# whether that step draws from NumPy's global RNG is not visible here, so seeding it is a
# cheap safeguard.
np.random.seed(0)
torch.manual_seed(0)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


Import and instantiate the MSCNN inversion model.

In [2]:
from modules_2.nn_inversion_model import MSCNNInversionModel

# Multi-scale CNN that takes the 2-channel Stokes input and predicts 3 * 21 = 63 values,
# matching the 63 columns of the reshaped MuRAM targets (presumably 3 atmospheric
# parameters sampled at 21 optical-depth points each -- confirm against the model module).
mscnn_model = MSCNNInversionModel(
    scales=[1, 2, 3],
    in_channels=2,          # channel axis of the Stokes tensor, shape (N, 2, n_wl)
    c1_filters=16,
    c2_filters=32,
    kernel_size=5,
    stride=1,
    padding=0,
    pool_size=2,
    n_linear_layers=4,
    output_features=3 * 21,
)
mscnn_model = mscnn_model.to(device)


Then load the data used for testing.

In [3]:
# DataCharger is already imported in the first cell (as modules_2.charge_data.DataCharger).
# Re-importing it here as the top-level module `charge_data` would execute the file a second
# time and rebind DataCharger to a different class object, so only the extra search path is
# kept. It is guarded so that re-running this cell does not keep growing sys.path.
MODULES_DIR = '../modules_2'
if MODULES_DIR not in sys.path:
    sys.path.append(MODULES_DIR)

# Identifiers of the files to load; add more as needed.
filenames = ["080000", "081000", "082000"]

# Root directory holding the input data files.
DATA_PATH = "/scratchsan/observatorio/juagudeloo/data"

data_charger = DataCharger(
    data_path=DATA_PATH,
    filenames=filenames,
    nx=480,
    ny=480,
    nz=256,
)

data_charger.charge_all_files()
stokes_data, muram_data, wfa_blos_minmax, best_muram_B_minmax = data_charger.reshape_for_training()

# Keep only the first 3 samples as a small test batch for the WFA check below.
stokes_data = torch.tensor(stokes_data[:3], dtype=torch.float32).to(device)
print(stokes_data.size())

Charging 3 files...
Processing file: 080000


Mapping to optical depth atm parameter 0: 100%|██████████| 480/480 [00:06<00:00, 71.20it/s]
Mapping to optical depth atm parameter 1: 100%|██████████| 480/480 [00:06<00:00, 70.94it/s]
Mapping to optical depth atm parameter 2: 100%|██████████| 480/480 [00:06<00:00, 70.95it/s]
Applying LSF: 100%|██████████| 2/2 [00:00<00:00,  2.06it/s]
Resampling wavelengths: 100%|██████████| 2/2 [00:04<00:00,  2.39s/it]
Adding noise: 100%|██████████| 480/480 [00:02<00:00, 166.79it/s]


Processing file: 081000


Mapping to optical depth atm parameter 0: 100%|██████████| 480/480 [00:06<00:00, 70.71it/s]
Mapping to optical depth atm parameter 1: 100%|██████████| 480/480 [00:06<00:00, 71.13it/s]
Mapping to optical depth atm parameter 2: 100%|██████████| 480/480 [00:06<00:00, 71.33it/s]
Applying LSF: 100%|██████████| 2/2 [00:00<00:00,  2.08it/s]
Resampling wavelengths: 100%|██████████| 2/2 [00:04<00:00,  2.32s/it]
Adding noise: 100%|██████████| 480/480 [00:02<00:00, 166.87it/s]


Processing file: 082000


Mapping to optical depth atm parameter 0: 100%|██████████| 480/480 [00:06<00:00, 71.13it/s]
Mapping to optical depth atm parameter 1: 100%|██████████| 480/480 [00:06<00:00, 71.01it/s]
Mapping to optical depth atm parameter 2: 100%|██████████| 480/480 [00:06<00:00, 71.19it/s]
Applying LSF: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s]
Resampling wavelengths: 100%|██████████| 2/2 [00:04<00:00,  2.38s/it]
Adding noise: 100%|██████████| 480/480 [00:02<00:00, 166.40it/s]


Data charging completed!
Stokes reshaped: (691200, 2, 112)
MuRAM reshaped: (691200, 63)
WFA B_LOS reshaped: (691200, 1)
Best MuRAM B reshaped: (691200, 1)
torch.Size([3, 2, 112])


## 2. WFA loss definition

### 2.1 WFA $B_\text{LOS}$

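Before defining the WFA loss, we need a function that computes $B_\text{LOS}$ from the Stokes profiles. In the weak-field approximation (WFA), Stokes $V$ is proportional to the wavelength derivative of Stokes $I$:

$$V(\lambda) = -C\, g\, \lambda_0^2\, B_\text{LOS}\, \frac{\partial I}{\partial \lambda}, \qquad C = \frac{e}{4\pi m_e c} \approx 4.67\times 10^{-13}\ \text{Å}^{-1}\,\text{G}^{-1}.$$

A linear least-squares fit of $V$ against $\partial I/\partial\lambda$ over a wavelength window therefore gives $B_\text{LOS}$ from the fitted slope.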

In [4]:
def compute_wfa_blos_pixel(stokes_pixel: np.ndarray,
                          ll: np.ndarray,
                          start_ll: int,
                          end_ll: int,
                          llambda0: float,
                          g: float,
                          stokes_v_index: int):
    """
    Estimate the line-of-sight magnetic field of one pixel with the weak-field approximation.

    The WFA relates Stokes V to the wavelength derivative of Stokes I,
    V(lambda) = -C * g * lambda0**2 * B_LOS * dI/dlambda, with C = e / (4 pi m_e c).
    A least-squares fit of V = p0 * dI/dlambda + p1 over the selected wavelength
    window gives B_LOS = -p0 / (C * lambda0**2 * g).

    Args:
    stokes_pixel: numpy.ndarray
        2D array of Stokes profiles for a single pixel, shape (n_stokes, n_wl).
        Row 0 is used as Stokes I.
    ll: numpy.ndarray or astropy.units.Quantity
        Wavelength axis. Plain arrays are interpreted as angstroms; Quantities are
        converted to angstroms.
    start_ll: int
        Index of the first wavelength sample of the fitting window (inclusive).
    end_ll: int
        End index of the fitting window (exclusive, Python slice semantics).
    llambda0: float or astropy.units.Quantity
        Rest wavelength. Plain numbers are interpreted as angstroms; Quantities are
        converted to angstroms.
    g: float
        Effective Landé factor of the line.
    stokes_v_index: int
        Row of `stokes_pixel` that holds Stokes V.

    Returns:
    astropy.units.Quantity
        Line-of-sight magnetic field in gauss.

    Raises:
    ValueError
        If the window [start_ll, end_ll) contains fewer than two samples, which is
        too few to take a numerical derivative.
    """
    # C = e / (4 pi m_e c), expressed in 1 / (G Angstrom) (about 4.67e-13).
    wfa_constant = (e.si / (4 * np.pi) / m_e / c).to(1 / u.G / u.Angstrom)

    # Normalise both wavelength inputs to angstroms. The fitted slope p0 is then in
    # angstroms and the final expression always reduces to gauss, regardless of
    # whether the caller passes plain numbers or astropy Quantities.
    ll_angstrom = u.Quantity(ll, u.Angstrom).to_value(u.Angstrom)
    llambda0 = u.Quantity(llambda0, u.Angstrom)

    stokes_i = stokes_pixel[0, start_ll:end_ll]
    V = stokes_pixel[stokes_v_index, start_ll:end_ll]
    if len(stokes_i) < 2:
        raise ValueError(
            f"WFA fitting window [{start_ll}, {end_ll}) must contain at least 2 "
            f"wavelength samples, got {len(stokes_i)}"
        )

    # dI/dlambda on the window, per angstrom (plain floats, no units attached).
    dI_dl = np.gradient(stokes_i) / np.gradient(ll_angstrom[start_ll:end_ll])

    # Least-squares fit of V = p0 * dI/dlambda + p1; the constant column absorbs any
    # offset in V (e.g. residual crosstalk or continuum level).
    a = np.zeros((len(V), 2))
    a[:, 0] = dI_dl
    a[:, 1] = 1.0
    p = np.linalg.pinv(a) @ V

    # p[0] has units of angstrom because the derivative was taken per angstrom.
    B = -p[0] * u.Angstrom / (wfa_constant * llambda0**2 * g)
    return B.to(u.G)


Compute $B_\text{LOS}$ for each sample of the test Stokes array.

In [5]:
# Fitting configuration: identical for every sample, so it is set once outside the loop.
ll = data_charger.wl_hinode      # wavelength axis, assumed to be in angstroms
start_ll = 20                    # first wavelength index of the fitting window (inclusive)
end_ll = 60                      # end of the fitting window (exclusive)
llambda0 = 6301.5*u.Angstrom     # rest wavelength (presumably Fe I 6301.5 A; confirm the 20:60 window covers it)
g = 1.67                         # effective Landé factor used for this line
stokes_v_index = 1               # row of each sample assumed to hold Stokes V (row 0 is used as Stokes I)

for i in range(stokes_data.size(0)):
    stokes_pixel = stokes_data[i].cpu().numpy()
    B_los = compute_wfa_blos_pixel(stokes_pixel, ll, start_ll, end_ll, llambda0, g, stokes_v_index)
    print(f"Sample {i}: B_LOS = {B_los:.2f}")

Sample 0: B_LOS = 2.21 G
Sample 1: B_LOS = -5.70 G
Sample 2: B_LOS = -0.29 G
