# CW fit short time dependence
Explore how the fit behaves when the sampled duration is shorter than one cycle period of the signal.

In [1]:
import numpy as np
import plotly.graph_objects as go
from typing import Dict

In [2]:
import math_functions
import numpy as np
import numpy.typing as npt
import plotly.graph_objects as go
from scipy.optimize import curve_fit

from math_functions import sine_wave
from plotly.subplots import make_subplots
from tqdm import tqdm
from typing import Tuple
from multiprocessing import Pool, cpu_count

from typing import Tuple

# Fixed-seed module-level generator so runs of this notebook are reproducible.
np_rng = np.random.default_rng(seed=42)


# Wrapper for curve_fit
def wrapper(
    x: npt.NDArray[np.float64],
    frequency: float,
    amplitude: float,
    phase: float,
) -> npt.NDArray[np.float64]:
    return sine_wave(
        x,
        frequency=frequency,
        phase=phase,
        amplitude=amplitude,
    )


# Positional parameter order of `wrapper` (and therefore of curve_fit's popt/pcov).
key_order = ("frequency", "amplitude", "phase")

# Ground-truth sine parameters used to synthesise the data.
true_params: Dict[str, float] = {
    "frequency": 1.0,
    "amplitude": 1.0,
    "phase": np.pi / 2,
}

# Experiment settings.
noise_scale = 0.01  # std-dev of the additive Gaussian noise
num_repeats: int = 30  # noisy realisations fitted per duration
num_points: int = 100  # samples taken across each duration
# Sample durations (s), 300 values log-spaced from 1e-3 to 1e2.
durations = np.logspace(start=-3.0, stop=2.0, num=300)


# Sampling can alias the true frequency. The aim of this section is to avoid
# durations whose apparent (aliased) frequency lands close to zero or close to
# the Nyquist frequency, by nudging those durations slightly.
# Step 1: Compute sampling frequencies (samples per unit time for each duration)
# NOTE(review): fit_for_duration samples with np.linspace(0, duration, num_points),
# whose spacing is duration / (num_points - 1); the effective sample rate is
# therefore (num_points - 1) / duration, slightly higher than computed here — confirm.
f_s_vals = num_points / durations
# Step 2: Compute apparent frequencies
# math_functions.apparent_frequency is defined elsewhere; presumably it folds
# f_true into [0, f_s / 2] — TODO confirm.
f_a_vals = np.array(
    [
        math_functions.apparent_frequency(true_params["frequency"], f_s)
        for f_s in f_s_vals
    ]
)
# Step 3: Normalize apparent frequencies by f_true
# NOTE(review): if apparent_frequency returns f_true when f_s > 2 * f_true
# (no aliasing), then f_a_norm == 1.0 there, which fails `< max_thresh`. That
# would flag every unaliased duration as invalid. The comment above talks
# about distance from Nyquist, which suggests normalising by f_s / 2 instead.
# f_N_norm (f_s normalised by 2 * f_true) is computed but not used below.
f_a_norm = f_a_vals / true_params["frequency"]
f_N_norm = f_s_vals / (2 * true_params["frequency"])
# Step 4: Mask out undesired edge cases
# A duration is "valid" when the normalised apparent frequency lies strictly
# inside (min_thresh, max_thresh).
min_thresh = 0.05
max_thresh = 0.95
valid: npt.NDArray[np.bool_] = (f_a_norm > min_thresh) & (f_a_norm < max_thresh)

# Step 5: Adjust durations for invalid cases by nudging a fraction of the Nyquist zone
delta: float = 0.1  # fraction of Nyquist zone width
adjusted_durations: npt.NDArray[np.float64] = durations.copy()

for idx in np.where(~valid)[0]:
    f_s = f_s_vals[idx]
    # Zone index: N - 1 = floor(2 * f_true / f_s).
    N = int(np.floor(2 * true_params["frequency"] / f_s)) + 1
    # Width in f_s of the band between the boundaries f_s = 2 f_true / N and
    # f_s = 2 f_true / (N + 1).
    f_s_zone_width = 2 * true_params["frequency"] * (1 / N - 1 / (N + 1))
    f_s_offset = delta * f_s_zone_width
    # Even zones are nudged up in f_s, odd zones down.
    direction = 1 if N % 2 == 0 else -1
    f_s_adjusted = f_s + direction * f_s_offset
    # NOTE(review): adjusted durations are not re-validated. For N == 1
    # (f_s > 2 * f_true) this lowers f_s by delta * f_true, which can push an
    # f_s just above 2 * f_true below Nyquist and introduce aliasing.
    adjusted_durations[idx] = num_points / f_s_adjusted

# Final assignment: the fits below use the adjusted durations.
durations = adjusted_durations


# Per-parameter result containers (rmse_results / stderr_results) are created
# after the parallel fits finish, once `results` is available.


def fit_for_duration(
    duration: float,
    base_seed: int = 42,
) -> Tuple[float, Dict[str, float], Dict[str, float]]:
    """Fit noisy sine realisations sampled over ``duration`` and summarise the errors.

    Generates ``num_repeats`` noisy copies of ``sine_wave(x, **true_params)``.
    Each copy uses ``num_points`` evenly spaced samples spanning ``[0, duration]``
    and Gaussian noise of std-dev ``noise_scale``. Each copy is fitted with
    ``curve_fit`` starting from a slightly perturbed copy of the true parameters.

    Args:
        duration: Length of the sampled window.
        base_seed: Base entropy for this task's random generator. It is combined
            with the bit pattern of ``duration``, so each duration gets its own
            reproducible stream regardless of which worker process evaluates it.

    Returns:
        ``(duration, rmse, stderr)``.
        ``rmse[k]`` is the root-mean-square of (fitted - true) for parameter ``k``.
        ``stderr[k]`` is the mean of ``sqrt(diag(pcov))`` for ``k``.
        Fits where ``curve_fit`` raises ``RuntimeError`` are recorded as NaN and
        skipped by the nan-aware means.
    """
    # Independent, reproducible generator per task. A shared module-level
    # generator is copied into each forked Pool worker in the same state, so
    # different durations could otherwise be fitted against identical noise.
    duration_bits = int(np.float64(duration).view(np.uint64))
    rng = np.random.default_rng([base_seed, duration_bits])

    x_vals = np.linspace(0, duration, num_points)
    # The noise-free signal and the true parameter vector are the same for every repeat.
    y_clean = sine_wave(x_vals, **true_params)
    true_vec = np.array([true_params[k] for k in key_order])

    param_diffs = {k: [] for k in key_order}
    param_stderr = {k: [] for k in key_order}

    for _ in range(num_repeats):
        y_noisy = y_clean + rng.normal(scale=noise_scale, size=y_clean.shape)
        # Perturb the starting point by a tiny relative amount around the true values.
        initial_guess = true_vec + rng.normal(scale=0.00001, size=true_vec.shape) * true_vec
        try:
            popt, pcov = curve_fit(wrapper, x_vals, y_noisy, p0=initial_guess)
        except RuntimeError:
            # curve_fit raises RuntimeError when the least-squares fit fails.
            for key in key_order:
                param_diffs[key].append(np.nan)
                param_stderr[key].append(np.nan)
            continue

        perr = np.sqrt(np.diag(pcov))
        # popt / perr follow wrapper's positional parameter order, i.e. key_order.
        for key, fitted, err in zip(key_order, popt, perr):
            param_diffs[key].append(fitted - true_params[key])
            param_stderr[key].append(err)

    rmse = {
        key: np.sqrt(np.nanmean(np.square(param_diffs[key]))) for key in true_params
    }
    stderr = {key: np.nanmean(param_stderr[key]) for key in true_params}

    return duration, rmse, stderr


print("Starting parallel fits...")
with Pool(processes=cpu_count()) as pool:
    # imap yields results in the same order as `durations`.
    results = list(tqdm(pool.imap(fit_for_duration, durations), total=len(durations)))

# Collect per-parameter series from the (duration, rmse, stderr) tuples.
rmse_results = {k: [] for k in true_params}
stderr_results = {k: [] for k in true_params}

durations_out = []
for duration, rmse, stderr in results:
    durations_out.append(duration)
    for key in true_params:
        rmse_results[key].append(rmse[key])
        stderr_results[key].append(stderr[key])

durations = np.array(durations_out)

# Plotting with Plotly: one subplot row per fitted parameter. The row count is
# taken from true_params so no empty row is created.
n_rows = len(true_params)
fig = make_subplots(rows=n_rows, cols=1, shared_xaxes=True, vertical_spacing=0.01)

for i, k in enumerate(true_params, start=1):
    fig.add_trace(
        go.Scatter(
            x=durations, y=rmse_results[k], mode="lines+markers", name=f"RMSE {k}"
        ),
        row=i,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=durations,
            y=stderr_results[k],
            mode="lines+markers",
            name=f"stderr {k}",
            line=dict(dash="dot"),
        ),
        row=i,
        col=1,
    )

    # Label each row with its parameter name, positioned in that row's axis domain.
    xref = "x domain" if i == 1 else f"x{i} domain"
    yref = "y domain" if i == 1 else f"y{i} domain"
    fig.add_annotation(
        text=f"<b>{k}</b>",
        x=0.5,
        y=0.95,
        xref=xref,
        yref=yref,
        showarrow=False,
        font=dict(size=14),
        align="left",
    )

fig.update_layout(
    title="Fit RMSE and stderr vs Sample Duration (each parameter in separate subplot)",
    height=900,
    width=900,
    showlegend=True,
)

# Log-log axes on every row; only the bottom (shared) x-axis carries the title.
fig.update_xaxes(type="log")
fig.update_yaxes(type="log")
fig.update_xaxes(title_text="Sample Duration (s)", row=n_rows, col=1)

fig.show()

Starting parallel fits...


100%|██████████| 300/300 [00:02<00:00, 123.54it/s]


In [3]:
import anal_fit_err
import importlib
import dataclasses

importlib.reload(anal_fit_err)

# Analytic uncertainty predictions per duration from the two models in
# anal_fit_err: the CW model and its short-duration variant.
theory_cw = {"amplitude": [], "frequency": [], "phase": []}
theory_cw_short = {"amplitude": [], "frequency": [], "phase": []}


for this_duration in durations:
    long_est = anal_fit_err.analy_err_in_fit_cw_sine(
        amplitude=true_params["amplitude"],
        samp_num=num_points,
        samp_time=this_duration,
        sigma_obs=noise_scale,
    )
    for name, series in theory_cw.items():
        series.append(getattr(long_est, name))

    short_est = anal_fit_err.analy_err_in_fit_cw_sine_short_dur(
        amplitude=true_params["amplitude"],
        frequency=true_params["frequency"],
        samp_num=num_points,
        samp_time=this_duration,
        sigma_obs=noise_scale,
        phase=true_params["phase"],
    )
    for name, series in theory_cw_short.items():
        series.append(getattr(short_est, name))


for key in true_params:
    # One log-log figure per parameter. It compares the RMSE and mean reported
    # stderr from the repeated fits against both analytic predictions.
    traces = [
        go.Scatter(
            x=durations, y=rmse_results[key], mode="lines+markers", name=f"RMSE {key}"
        ),
        go.Scatter(
            x=durations,
            y=stderr_results[key],
            mode="lines+markers",
            name=f"stderr {key}",
            line=dict(dash="dot"),
        ),
        go.Scatter(
            x=durations,
            y=theory_cw[key],
            mode="lines",
            name=f"theory long {key}",
            line=dict(dash="dash"),
        ),
        go.Scatter(
            x=durations,
            y=theory_cw_short[key],
            mode="lines",
            name=f"theory short {key}",
            line=dict(dash="dash"),
        ),
    ]

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    fig.update_layout(
        title=f"Comparison for {key}",
        xaxis_title="Sample Duration (s)",
        yaxis_title="Uncertainty",
        width=900,
        height=500,
        yaxis_type="log",
        xaxis_type="log",
    )
    fig.show()