# Fitting the SOHO/VIRGO Solar Irradiance Power Spectrum

##### Brett Morris

In [None]:
%matplotlib inline
import os
import json
from functools import partial

from tqdm.auto import tqdm

import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from astropy.time import Time
import astropy.units as u
from astropy.stats import mad_std

from gadfly import PowerSpectrum
from gadfly.sun import broomhall_p_mode_freqs

from gadfly.psd import (
    linear_space_to_jax_parameterization, 
    jax_parameterization_to_linear_space,
    linear_space_to_dicts, ppm, to_psd_units
)

from scipy.stats import binned_statistic
from lightkurve import LightCurve

from jax import jit
import jax.numpy as jnp
from jax.scipy.optimize import minimize

from celerite2.jax import GaussianProcess, terms

The VIRGO/PMO6 1-minute time series is accessible online at: 

    ftp://ftp.pmodwrc.ch/pub/data/irradiance/virgo/old/1-minute_Data/VIRGO_1min_0083-7404.fits
    
We first load the VIRGO observations:

In [None]:
# Load the raw VIRGO/PMO6 1-minute irradiance series and its FITS header.
# The context manager closes the file handle once the primary HDU's data
# and header have been read (the original never closed it).
with fits.open('data/VIRGO_1min_0083-7404.fits.gz') as hdu:
    raw_fluxes = hdu[0].data
    header = hdu[0].header

header

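We reconstruct the sample times from the header: the `TIME` keyword gives the offset (in days) from the SOHO mission epoch (1995-12-01), and samples are spaced one minute apart. Missing measurements, flagged with the value -99, are filled by linear interpolation. 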

In [None]:
soho_mission_day = Time("1995-12-1 00:00")

# Sample times in JD: the header's TIME keyword offsets the mission epoch
# (in days), and consecutive samples are 1 minute (1/1440 day) apart.
times = (
    soho_mission_day.jd +
    header['TIME'] +
    np.arange(header['NAXIS1']) / 1440
)
times_astropy = Time(times, format='jd')

# Missing samples are flagged with -99 in the raw data; fill them by
# linear interpolation in time between the valid samples.
is_missing = raw_fluxes == -99
fluxes = raw_fluxes.copy()
fluxes[is_missing] = np.interp(
    times[is_missing], times[~is_missing], fluxes[~is_missing]
)
# Sample cadence of the time series.
d = (times[1] - times[0]) * u.day

# Convert to relative flux (in ppm) about the median.
fluxes = 1e6 * (fluxes / np.median(fluxes) - 1) * ppm
fluxes_std_ppm = mad_std(fluxes.value)

# Plot every ``skip_every``-th sample. ``plt.plot_date`` is deprecated
# (Matplotlib 3.9), so set up the date axis explicitly and use ``plot``,
# which is what ``plot_date`` did internally.
skip_every = 500
plt.gca().xaxis_date()
plt.plot(
    times_astropy[::skip_every].plot_date,
    fluxes[::skip_every], '.'
)
plt.xlabel('Date')
plt.ylabel(f'Flux [{fluxes.unit}]');

Compute the full and binned power spectrum: 

In [None]:
# Wrap the normalized, gap-filled fluxes in a lightkurve LightCurve so that
# gadfly can build the periodogram directly from it.
solar_light_curve = LightCurve(time=times_astropy, flux=fluxes)

# Full-resolution power spectrum. The gaps were already interpolated over
# above, so gadfly's ``interpolate_and_detrend`` option is turned off.
solar_power_spectrum = PowerSpectrum.from_light_curve(
    solar_light_curve,
    interpolate_and_detrend=False,
)

# Binned version (len(fluxes) // 10000 bins) for plotting and fitting.
solar_power_spectrum_binned = solar_power_spectrum.bin(
    len(fluxes) // 10000
)

In [None]:
# Show the whole binned spectrum on a wide canvas, without the p-mode inset.
fig, ax = plt.subplots(figsize=(15, 5))
solar_power_spectrum_binned.plot(ax=ax, p_mode_inset=False)

Set frequency boundaries that narrow the range over which we will fit the power spectrum:

In [None]:
# Quick look at the binned spectrum on gadfly's default axes (no p-mode
# inset), to help choose the frequency bounds used in the next cell.
solar_power_spectrum_binned.plot(p_mode_inset=False)

In [None]:
cutoff_freq_max = 5000 * u.uHz
cutoff_freq_min = 0.01 * u.uHz

# Keep only bins strictly between the two cutoffs.
in_bounds = np.logical_and(
    solar_power_spectrum_binned.frequency > cutoff_freq_min,
    solar_power_spectrum_binned.frequency < cutoff_freq_max,
)

# Frequencies in uHz, powers and uncertainties in PSD units (bare values).
x = solar_power_spectrum_binned.frequency[in_bounds].to(u.uHz).value
y = to_psd_units(solar_power_spectrum_binned.power[in_bounds]).value
yerr = to_psd_units(solar_power_spectrum_binned.error[in_bounds]).value

# Drop bins whose power is NaN so they never enter the fits below.
mask_y = ~np.isnan(y)
x, y, yerr = x[mask_y], y[mask_y], yerr[mask_y]

# Replace unusable uncertainties (NaN or non-positive) with a large
# placeholder value, presumably to down-weight those bins in chi^2.
mask_fix_yerr = np.isnan(yerr) | (yerr <= 0)
yerr[mask_fix_yerr] = 1e4

Fit the low-frequency features in the solar power spectrum.

In [None]:
# Quality factor shared by every low-frequency SHO term.
fixed_Q = 0.6

# Approximate p-mode band in uHz; bins inside it are excluded from the
# low-frequency fit (see ``chi2_low_freq_model``).
min_p_mode, max_p_mode = 2300, 3900

mask_p_modes = np.logical_or(x < min_p_mode, x > max_p_mode)


def sho(S0, w0, fixed_Q=fixed_Q):
    """
    Build a celerite2 underdamped SHO term with amplitude ``S0``,
    angular frequency ``w0`` and quality factor ``fixed_Q``.
    """
    term = terms.UnderdampedSHOTerm(S0=S0, w0=w0, Q=fixed_Q)
    return term

def low_freq_kernels_five_kernels(p):
    """
    Sum of five underdamped SHO kernels, meant for fitting
    low frequency features in the solar power spectrum.

    ``p`` holds ten values read as five consecutive pairs.

    * Amplitudes are accumulated backwards. ``p[8]`` is the S0 of the
      last kernel, and each earlier kernel ``i`` adds ``10 ** p[2 * i]``
      on top of kernel ``i + 1``.
    * Angular frequencies are accumulated forwards. ``p[1]`` is the w0 of
      the first kernel, and each later kernel ``i`` adds
      ``10 ** p[2 * i + 1]`` to kernel ``i - 1``.

    Because the increments are powers of ten, S0 runs from largest (first
    kernel) to smallest and w0 from smallest to largest.
    """
    n_terms = 5

    # Amplitudes: start from the directly-parameterized last one.
    amplitudes = [None] * n_terms
    amplitudes[-1] = p[2 * (n_terms - 1)]
    for i in range(n_terms - 2, -1, -1):
        amplitudes[i] = 10 ** p[2 * i] + amplitudes[i + 1]

    # Frequencies: start from the directly-parameterized first one.
    frequencies = [p[1]]
    for i in range(1, n_terms):
        frequencies.append(frequencies[-1] + 10 ** p[2 * i + 1])

    return terms.TermSum(
        *(sho(S0, w0) for S0, w0 in zip(amplitudes, frequencies))
    )

def jax_model_low_freq(p, omega):
    """
    Evaluate the five-term low-frequency kernel's power spectral
    density at the angular frequencies ``omega``.
    """
    kernel = low_freq_kernels_five_kernels(p)
    return kernel.get_psd(omega)

@jit
def chi2_low_freq_model(
    p, y=y[mask_p_modes], omega=2*np.pi*x[mask_p_modes], yerr=yerr[mask_p_modes]
):
    """
    chi^2 of the low-frequency model ``jax_model_low_freq`` for
    parameters ``p``.

    By default the comparison uses the binned spectrum outside the p-mode
    band (``mask_p_modes``). ``y``, ``omega`` and ``yerr`` are ordinary
    (traced) arguments, so other arrays can be passed explicitly. They are
    deliberately not marked static, because JAX requires static arguments
    to be hashable and NumPy/JAX arrays are not.
    """
    # nansum ignores bins where the model or data produce NaN.
    chi2_result = jnp.nansum(
        (y - jax_model_low_freq(p, omega=omega))**2 / yerr**2
    )
    return chi2_result


fit = True

# Initial guesses for the five low-frequency SHO terms: amplitudes S0 and
# (presumably angular) frequencies w0. Other starting points tried:
#   S0 = 3   * 10 ** [4.1, 1.3, -0.5, -0.8, -1.6],
#   w0 = [5e0, 9.5e1, 6e2, 6.3e3, 2.3e4]   (noted as working well)
#   S0 = 1.5 * 10 ** [4.1, 1.3, -0.5, -0.8, -1.6],
#   w0 = [5e0, 9.5e1, 7e2, 6.6e3, 2.6e4]
all_S0s = 1.5 * 10 ** np.array([4.1, 1.2, -0.3, -0.8, -1.6])
all_omegas = np.array([5e0, 9.5e1, 6e2, 6.3e3, 2.3e4])

# Convert the linear-space guesses into the parameter vector layout that
# low_freq_kernels_five_kernels consumes.
initp = linear_space_to_jax_parameterization(all_S0s, all_omegas)

fig, ax = plt.subplots(figsize=(12, 5))

if fit:
    result = minimize(chi2_low_freq_model, jnp.array(initp), method='bfgs')
    print('Fit successful:', result.success)
    # Status 3 signals that NaNs were encountered; status codes follow
    # scipy's, see
    # https://github.com/scipy/scipy/blob/85d25b6e4a9b95371e48bae75c19459a0b77d18e/scipy/optimize/_optimize.py#L1239-L1242
    if result.status == 3:
        print('Warning: nans encountered in optimization')
    bestp_lowfreq = result.x

def numpy_model(p, omega=2*np.pi*x):
    """Evaluate the low-frequency PSD model and return it as a NumPy array."""
    psd = jax_model_low_freq(p, omega=omega)
    return np.array(psd)

# Overlay the best-fit (when fitted) and initial low-frequency models on the
# binned spectrum, shading the p-mode band that the fit excludes.
if fit:
    ax.loglog(x, numpy_model(bestp_lowfreq), color='C0', lw=5, label='Fit', zorder=0)

ax.loglog(x, numpy_model(initp), color='r', ls=':', label='Init', zorder=10)
ax.errorbar(x, y, None, color='k', ecolor='silver', label='Binned', fmt='.', rasterized=True, zorder=10)
ax.axvspan(min_p_mode, max_p_mode, alpha=0.2)
# The backslash must be escaped: '\m' is an invalid escape sequence in a
# normal string literal (SyntaxWarning on Python 3.12+).
ax.set_xlabel('Frequency ($\\mu$Hz)')
ax.set_ylabel('Power')
ax.set_xlim([cutoff_freq_min.value, cutoff_freq_max.value])
ax.set_ylim([1e-3, 1e5])
plt.legend()

Helpful diagnostic plot for understanding if the errorbars are over/underestimated:

In [None]:
yerr_diagnostic_plot = False

if yerr_diagnostic_plot:
    # Normalized residuals |O - C| / E of the low-frequency fit. Values
    # mostly well above 1 hint that the errorbars are underestimated;
    # values mostly well below 1 hint that they are overestimated.
    normalized_residuals = np.abs((y - numpy_model(bestp_lowfreq)) / yerr)

    fig, ax = plt.subplots(figsize=(12, 5))
    ax.loglog(x, normalized_residuals, 'k.')
    ax.set(xlabel='Freq ($\\mu$Hz)', ylabel='abs((O - C) / E)')
    ax.axvspan(min_p_mode, max_p_mode, alpha=0.2)

Parse a list of $p$-mode frequencies from Broomhall 2009 Table 2: 

In [None]:
# Solar p-mode frequencies from Broomhall et al. (2009), Table 2, as returned
# by gadfly.sun.broomhall_p_mode_freqs. The fitting loop below treats each
# entry as a frequency Quantity (it converts them with ``.to(u.uHz)``).
table2 = broomhall_p_mode_freqs()

Fit each power spectrum peak near the $p$-modes.

* For each peak, try several initial guesses; if none of them converges, inflate the errorbars (1x, 3x, 6x, 9x, 12x) and try the initial guesses again
* Stop at the first fit that converges
* If a fit converges, save its best-fit parameters, provided the improvement in BIC over the low-frequency-only model (∆BIC) exceeds 10 

In [None]:
# Slice this (e.g. ``table2[34:]``) to fit only a subset of the peaks.
table2_peaks_to_fit = table2
peaks_progress_bar = tqdm(table2_peaks_to_fit)

parameters = []

peak_fitting_width = 6 * u.uHz
core_avoid_width = 3.5 * u.uHz
plot_if_fit_unsuccessful = False
plot_if_fit_successful = True

# Minimum BIC improvement over the low-frequency-only model required to
# keep a converged p-mode fit.
delta_bic_threshold = 10

# Initial guesses for [log10(S0), log10(Q)] of each p-mode SHO term,
# tried in order at each errorbar scaling.
initial_tries = [
    jnp.array([-7.3, 3]),
    jnp.array([-6.3, 2.5]),
    jnp.array([-5.3, 2]),
]
# Errorbar inflation factors [1, 3, 6, 9, 12], tried in order when no
# initial guess converges at the previous scaling.
error_scales = [1] + np.arange(3, 15, 3).tolist()


def compute_delta_bic(init_chi2, best_chi2, attempt_initp, omegas_i):
    """
    Compare the BIC of the low-frequency-only model with the BIC of the
    low-frequency + p-mode model.

    Returns ``(delta_bic, final_bic, init_bic)`` with
    ``delta_bic = init_bic - final_bic``; positive values favor the model
    that includes the p-mode.
    """
    log_n = np.log(len(omegas_i))
    # The low-frequency kernel is held fixed at ``bestp_lowfreq`` in both
    # models, so it contributes no free parameters.
    n_parameters_low_freq = 0
    init_bic = init_chi2 + n_parameters_low_freq * log_n
    # The p-mode SHO term has one free parameter per entry of the initial
    # guess (log10 S0 and log10 Q). Dividing this by the number of
    # low-frequency kernels would floor 2 // 5 to zero and drop the penalty.
    n_parameters_p_mode = len(attempt_initp)
    final_bic = best_chi2 + (n_parameters_p_mode + n_parameters_low_freq) * log_n
    return init_bic - final_bic, final_bic, init_bic


def attempt_fit(initp, power_err, expected_omega, omegas_i, powers_i, err_scaling=1):
    """
    Fit one SHO p-mode term, on top of the fixed low-frequency kernel, to
    the binned powers ``powers_i`` at angular frequencies ``omegas_i``.

    ``initp`` is the initial ``[log10(S0), log10(Q)]``. ``w0`` is held
    fixed at ``expected_omega``. ``power_err`` is the (already scaled)
    per-bin uncertainty used in chi^2. ``err_scaling`` is only recorded
    in the result, for plotting.

    Returns a dict with these keys:

    * ``minimizer_result``: the minimizer output.
    * ``init_model``, ``init_low_freq_model``, ``best_model``: the
      initial, low-frequency-only and best-fit PSD models.
    * ``init_chi2``, ``best_chi2``: chi^2 of the low-frequency-only model
      and of the best-fit model.
    * ``best_fit_parameters``: the parameters actually used for
      ``best_model``.
    * ``initp``, ``err_scaling``: the inputs, echoed back.
    """
    # The low-frequency-only model does not depend on the p-mode parameters.
    init_low_freq_model = low_freq_kernels_five_kernels(bestp_lowfreq).get_psd(omegas_i)

    @jit
    def jax_model(p):
        kernel = (
            terms.UnderdampedSHOTerm(
                S0=jnp.power(10.0, p[0]), w0=expected_omega, Q=jnp.power(10.0, p[1])
            ) + low_freq_kernels_five_kernels(bestp_lowfreq)
        )
        return kernel.get_psd(omegas_i)

    @jit
    def chi2_underdamped(p):
        return jnp.nansum((jax_model(p) - powers_i)**2 / power_err**2)

    init_model = jax_model(initp)
    minimizer_result = minimize(chi2_underdamped, initp, method='bfgs')
    best_fit_parameters = minimizer_result.x
    if np.any(minimizer_result.x > 100):
        # prevent any bonkers "solutions" from being computed
        # and blowing up with infs/nans
        best_fit_parameters = initp.copy()
    best_model = jax_model(best_fit_parameters)
    init_chi2 = jnp.nansum((init_low_freq_model - powers_i)**2 / power_err**2)
    best_chi2 = chi2_underdamped(best_fit_parameters)
    return dict(
        minimizer_result=minimizer_result,
        init_model=init_model,
        init_low_freq_model=init_low_freq_model,
        best_model=best_model,
        init_chi2=init_chi2,
        best_chi2=best_chi2,
        best_fit_parameters=best_fit_parameters,
        initp=initp,
        err_scaling=err_scaling,
    )


def plot_peak_fit(binned_spectrum, omegas_i, powers_i, errorbar, initial_models,
                  best_model, title, fmt='.', layer_initial_models=False):
    """
    Plot one peak's binned power (with uncertainty ``errorbar``), the
    observed powers, each initial-guess model and the best-fit model.
    """
    frequency_i = omegas_i / (2*np.pi)
    _, ax = plt.subplots(figsize=(14, 4))
    ax.errorbar(
        binned_spectrum.frequency.value, to_psd_units(binned_spectrum.power).value,
        to_psd_units(errorbar).value, fmt=fmt, color='gray', ecolor='silver'
    )
    ax.plot(frequency_i, to_psd_units(powers_i), lw=2, color='k')
    for j, k in enumerate(initial_models):
        zorder_kwargs = dict(zorder=10+j) if layer_initial_models else dict()
        ax.plot(
            frequency_i, to_psd_units(initial_models[k]), lw=2,
            color=f'C{j}', alpha=0.3, **zorder_kwargs
        )
    ax.plot(frequency_i, to_psd_units(best_model), lw=4, color='r')
    ax.set(
        xlabel=f'Frequency [{binned_spectrum.frequency.unit.to_string("latex")}]',
        # Power density unit (previously the frequency unit was shown here
        # on the success plot by mistake).
        ylabel=f'Power Density [{to_psd_units(binned_spectrum.power[0]).unit.to_string("latex")}]',
        title=title,
    )
    plt.show()


for peak in peaks_progress_bar:
    power_spectrum_cutout = solar_power_spectrum.cutout(
        frequency_max=peak + peak_fitting_width,
        frequency_min=peak - peak_fitting_width
    )
    bins = len(power_spectrum_cutout.frequency) // 100
    binned_spectrum = power_spectrum_cutout.bin(bins, log=False)

    # Estimate the per-bin scatter from the wings of the cutout, away from
    # the peak core, so the p-mode itself does not inflate it.
    off_core = (
        (binned_spectrum.frequency > peak + core_avoid_width) |
        (binned_spectrum.frequency < peak - core_avoid_width)
    )
    omegas_i = 2 * jnp.pi * binned_spectrum.frequency.to(u.uHz).value
    powers_i = to_psd_units(binned_spectrum.power).value
    power_err = mad_std(to_psd_units(binned_spectrum.power[off_core]).value, ignore_nan=True)
    expected_omega = 2 * np.pi * peak.to(u.uHz).value

    peak_message = f"freq = {expected_omega/(2*np.pi):.1f} uHz"
    peaks_progress_bar.set_description(
        f"[Fitting p-mode {peak_message}]"
    )

    success = False
    chosen_fit = None
    # The bar is advanced manually, once per attempt.
    err_scale_progress_bar = tqdm(total=len(error_scales) * len(initial_tries))

    for err_scaling in error_scales:
        initial_models = dict()
        best_failed_fit = None
        last_delta_bic = None
        for init_counter, attempt_initp in enumerate(initial_tries):
            err_scale_progress_bar.set_description(
                f"PSD fitting at {peak_message} [err scale={err_scaling}x, try initial guess {init_counter}]"
            )
            err_scale_progress_bar.update()
            attempt = attempt_fit(
                attempt_initp, err_scaling * power_err, expected_omega,
                omegas_i, powers_i, err_scaling=err_scaling
            )
            initial_models[init_counter] = attempt['init_model']

            if attempt['minimizer_result'].success:
                # Keep the attempt that actually converged, not an earlier
                # non-converged attempt that happened to have a larger ∆BIC.
                chosen_fit = attempt
                success = True
                break

            # Remember the most promising failed attempt for diagnostics.
            delta_bic, _, _ = compute_delta_bic(
                attempt['init_chi2'], attempt['best_chi2'], attempt_initp, omegas_i
            )
            if last_delta_bic is None or last_delta_bic < delta_bic:
                last_delta_bic = float(delta_bic)
                best_failed_fit = attempt

        if success:
            break

        if plot_if_fit_unsuccessful:
            reduced_chi2 = best_failed_fit['best_chi2'] / len(omegas_i)
            plot_peak_fit(
                binned_spectrum, omegas_i, powers_i,
                best_failed_fit['err_scaling'] * power_err,
                initial_models, best_failed_fit['best_model'],
                title=f"Failed fit, $\\tilde{{\\chi}}^2 = {reduced_chi2:.2f}$",
                fmt='.',
            )

    # ``close`` works for both console and notebook tqdm bars.
    err_scale_progress_bar.close()

    if not success:
        print(f"minimize unsuccessful, skipping fit for {peak_message}")
        continue

    delta_bic, final_bic, init_bic = compute_delta_bic(
        chosen_fit['init_chi2'], chosen_fit['best_chi2'], chosen_fit['initp'], omegas_i
    )
    accept_fit = delta_bic > delta_bic_threshold

    if accept_fit:
        # Save the parameters actually used for the best-fit model (which
        # differ from the raw minimizer output if the >100 guard fired).
        parameters.append(
            np.asarray(chosen_fit['best_fit_parameters']).tolist() + [expected_omega]
        )
        print(f"Successful fit for {peak_message}. ∆BIC={delta_bic:.1f}")
    else:
        print(f"Successful fit but ∆BIC={delta_bic:.1f}<"
              f"{delta_bic_threshold:.0f} for {peak_message}")

    if plot_if_fit_successful:
        plot_peak_fit(
            binned_spectrum, omegas_i, powers_i,
            chosen_fit['err_scaling'] * power_err,
            initial_models, chosen_fit['best_model'],
            title=(
                f"{'Accepted' if accept_fit else 'Rejected'} successful fit "
                f"with ∆BIC={delta_bic:.1f}"
            ),
            fmt=',',
            layer_initial_models=True,
        )

# Rows: [log10(S0), log10(Q), w0] for each accepted p-mode.
parameters = np.array(parameters)

The best-fit parameters are in a numpy array with columns: 

    [log10(S0), log10(Q), w0 (angular frequency: 2π × frequency in uHz)]
    
Now construct the `full_kernel` which has the low frequency components and the p-modes:

In [None]:
# Start from the fitted low-frequency kernel (five SHO terms with fixed Q).
full_kernel = low_freq_kernels_five_kernels(bestp_lowfreq)

# begin to save the kernel parameters with the low frequency
# kernel solutions found in the fits to the granulation and
# supergranulation, etc. These have fixed Q's:
kernel_parameters = linear_space_to_dicts(
    *jax_parameterization_to_linear_space(bestp_lowfreq), fixed_Q=fixed_Q
)

# Then add a list of dicts for parameters that describe
# the comb of p-mode peaks. These have fixed w0's. Each row of
# ``parameters`` is [log10(S0), log10(Q), w0]:
for p in parameters:
    kwargs = dict(
        S0=np.power(10.0, p[0]),
        w0=p[2],
        Q=np.power(10.0, p[1])
    )
    full_kernel += terms.UnderdampedSHOTerm(**kwargs)
    kernel_parameters.append(kwargs)

# Compare model and data in the units the fits were done in: frequencies
# in uHz (the kernel takes angular frequency 2*pi*nu) and powers passed
# through ``to_psd_units``. Plotting the raw ``.value``s would mix units.
full_psd = full_kernel.get_psd(2 * np.pi * solar_power_spectrum.frequency.to(u.uHz).value)

plt.loglog(solar_power_spectrum.frequency.to(u.uHz).value, full_psd)
plt.loglog(
    solar_power_spectrum_binned.frequency.to(u.uHz).value,
    to_psd_units(solar_power_spectrum_binned.power).value,
    '.k'
)

Save the list of dictionaries of kernel hyperparameters to a json file:

In [None]:
overwrite = True
parameter_vector_path = 'parameter_vector.json'

if os.path.exists(parameter_vector_path) and not overwrite:
    print('skipping overwrite')
else:
    # Plain list of hyperparameter dicts, one per kernel term.
    with open(parameter_vector_path, 'w') as w:
        json.dump(kernel_parameters, w, indent=4)

    # The same hyperparameters, each tagged with the parameter held fixed
    # while fitting. Terms whose Q equals ``fixed_Q`` come from the
    # low-frequency fit (fixed Q); all others are treated as p-modes
    # (fixed w0).
    parameters_with_metadata = [
        dict(
            hyperparameters=p,
            metadata=dict(
                fixed_parameters=['Q'] if p['Q'] == fixed_Q else ['w0']
            )
        )
        for p in kernel_parameters
    ]

    with open('hyperparameters.json', 'w') as w:
        json.dump(parameters_with_metadata, w, indent=4)

In [None]:
from gadfly import Hyperparameters, StellarOscillatorKernel, PowerSpectrum
from gadfly.sun import download_soho_virgo_time_series

# Generate a celerite2 kernel that approximates the solar
# power spectrum from the hyperparameters in hyperparameters.json.
# Read the file inside a context manager so the handle is closed
# once it has been parsed.
with open('hyperparameters.json', 'r') as hyperparameters_file:
    hp = Hyperparameters(json.load(hyperparameters_file))
new_kernel = StellarOscillatorKernel(hp, name='new solar')

# Plot the kernel's PSD, and the observed (binned) solar PSD:
obs_kwargs = dict(zorder=-10, marker='.', color='k', lw=0)
fig, ax = new_kernel.plot(
    # also plot the observed power spectrum
    obs=solar_power_spectrum_binned, obs_kwargs=obs_kwargs
)

Compare the above figure with [Fröhlich 1997](https://link.springer.com/article/10.1023/A:1004969622753)!