In [1]:
%load_ext autoreload
%autoreload 3 -p

In [108]:
import numpy as np
from dask.diagnostics import ProgressBar

import xarray
import pandas as pd
import holoviews as hv
hv.extension("bokeh")
from holoviews.plotting.util import colorcet_cmap_to_palette
import colorcet as cc
from src.estimator_analysis import get_results

# Notebook-wide plot defaults. Scatter, Image and VectorField share a square
# 400x400 frame with a grid; Path uses a shorter frame and a thicker line.
_square_frame = dict(frame_height=400, frame_width=400, show_grid=True)

hv.opts.defaults(
    hv.opts.Scatter(tools=["hover"], **_square_frame),
    hv.opts.Image(tools=["hover"], colorbar=True, **_square_frame),
    hv.opts.Path(frame_height=300, frame_width=400, tools=["hover"], show_grid=True, line_width=2),
    # Vector fields are drawn without any tools (no hover).
    hv.opts.VectorField(tools=[], **_square_frame),
)

In [151]:
# Inputs: the image stack and the estimator output written by the same run.
DATA_FILE = "/disk1/tid/users/starr/results/20260126_2340_data_planar.zarr"
RESULTS_FILE = "/disk1/tid/users/starr/results/20260126_2340_results_planar.zarr"

data = xarray.open_zarr(DATA_FILE)
results = get_results(RESULTS_FILE)

# Work with the SNR on a log10 scale from here on.
results["patch_freq_snr"] = np.log10(results["patch_freq_snr"])

# Errors normalised by the corresponding reference quantity
# (presumably spd / tau / lam are the true values of each trial -- confirm).
results["rel_speed_error"] = results["speed_error"] / results["spd"]
results["rel_period_error"] = results["period_error"] / results["tau"]
results["rel_wavelength_error"] = results["wavelength_error"] / results["lam"]

# Same relative errors for the alternative phase-speed estimate. The alternative
# wavelength is the reciprocal of (1000 * weighted mean frequency / phase speed);
# the factor 1000 is presumably a unit conversion -- confirm.
results["rel_speed_error_alt"] = (results["alt_phase_speed"] - results["spd"]) / results["spd"]
krms = 1000 * results["weighted_freq_mean"] / results["alt_phase_speed"]
alt_wavelength = 1 / krms
results["rel_wavelength_error_alt"] = (alt_wavelength - results["lam"]) / results["lam"]

with ProgressBar():
    # compute() returns a NEW in-memory object and leaves the original lazy, so
    # the result must be rebound; otherwise the whole computation is thrown away
    # and repeated by every later reduction / to_dataframe() call.
    results = results.compute()

# Median over time of every variable.
r = results.median(["time"])

[########################################] | 100% Completed | 166.78 s


In [152]:
# Time-median results as a flat table.
df = r.to_dataframe()

# Every individual estimate, keeping only the rows in which no column is null.
df2 = results.to_dataframe().dropna(axis="index", how="any")

In [153]:
Y_vars = [
    "rel_period_error",
    "rel_wavelength_error",
    "rel_speed_error",
    "rel_speed_error_alt",
    "rel_wavelength_error_alt",
]


def _density_histogram(column):
    """Return a density histogram of ``df2[column]``.

    Bins are chosen with the "doane" rule over the 0.1%-99.9% quantile range of
    the column. The panel label shows the mean and the median of the full
    (unclipped) column, in that order.
    """
    values = df2[column]
    clipped_range = np.quantile(values, [0.001, 0.999])
    binned = np.histogram(values, bins="doane", range=clipped_range, density=True)
    label = f"{np.mean(values):.2f}, {np.median(values):.2f}"
    return (
        hv.Histogram(binned, kdims=column, vdims=column+"density")
        .opts(frame_width=350, frame_height=350, show_grid=True, show_legend=False, tools=["hover"])
        .relabel(label)
    )


# One panel per error metric, laid out three to a row.
P = hv.Layout([_density_histogram(column) for column in Y_vars])
P.cols(3)

In [67]:
X = "rel_period_error"
Y = "rel_wavelength_error"

# Crop both axes to their 0.1%-99.9% quantile range (bounds inclusive) so that
# outliers do not stretch the plot.
xlim = np.quantile(df2[X], [.001, .999])
ylim = np.quantile(df2[Y], [.001, .999])
m = df2[X].between(xlim[0], xlim[1]) & df2[Y].between(ylim[0], ylim[1])

hextile_opts = dict(
    frame_width=500,
    frame_height=500,
    show_grid=True,
    show_legend=False,
    tools=["hover"],
    gridsize=100,
    min_count=100,  # hide sparsely populated tiles
    colorbar=True,
)
hv.HexTiles(df2[m], kdims=[X, Y]).opts(**hextile_opts)

In [None]:
from scipy.signal.windows import kaiser
from scipy.fft import fft2, fftfreq

# Pick a single trial from the image dataset and load it into memory.
trial = 50
D = data.sel(trial=trial).compute()

# Patch geometry and FFT size.
block_size = 32  # edge length of the rolling window, in samples
step_size = 8    # stride between consecutive windows, in samples
hres = 20        # sample spacing passed to fftfreq
Nfft = 128       # zero-padded FFT length per axis
# Number of windows discarded at each border of the patch grid.
edges = block_size // (2 * step_size)

# Separable 2-D Kaiser window (beta = 5). Note that it is divided by the sum of
# the 1-D taper, not by the sum of the 2-D window.
taper = kaiser(block_size, 5)
window = np.outer(taper, taper) / taper.sum()

# Cut the image into overlapping block_size x block_size patches, one every
# step_size samples. The samples inside each patch live on the new "kx"/"ky"
# dimensions; the patch positions end up on "px"/"py".
patches = (
    D.image
    .rolling(y=block_size, x=block_size, center=True)
    .construct(x="kx", y="ky", stride=step_size)
    # Taper every patch with the 2-D window.
    .pipe(lambda x: x * window)
    # Drop `edges` patch positions at each border (presumably the ones whose
    # centred window overhangs the image and is NaN-padded -- confirm).
    .isel(x=slice(edges, -edges), y=slice(edges, -edges))
    .rename({"x": "px", "y": "py"})
)
# Standardise every patch to zero mean and unit standard deviation over its own samples.
# NOTE(review): the mean is removed AFTER the taper was applied, and the FFT
# step later in this cell multiplies by `window` a second time -- confirm that
# the double windowing is intended.
patches = (patches - patches.mean(["kx", "ky"])) / patches.std(["kx", "ky"])
# FFT sample frequencies for an Nfft-point transform with sample spacing hres
# (cycles per unit of hres), in standard FFT order.
wavenum = fftfreq(Nfft, hres)

# 2-D FFT of every patch over its ("kx", "ky") axes. fft2's `s` argument
# zero-pads each patch to Nfft x Nfft, which is why the core dimensions change
# size and must be listed in `exclude_dims` / `output_sizes`.
F = (
    xarray.apply_ufunc(
        # NOTE(review): `patches` has already been multiplied by `window` when
        # it was built, so the taper is applied twice overall -- confirm.
        lambda x: fft2(x * window, s=(Nfft, Nfft)),
        patches,
        input_core_dims=[["kx", "ky"]],
        output_core_dims=[["kx", "ky"]],
        output_dtypes=[np.complex128],
        dask_gufunc_kwargs={"output_sizes": {"kx": Nfft, "ky": Nfft}},
        dask="parallelized",
        exclude_dims={"kx", "ky"}
    )
    # Label the spectral axes with their wavenumbers, then sort them so that
    # both axes increase monotonically instead of following FFT order.
    .assign_coords(kx=wavenum, ky=wavenum)
    .sortby("kx").sortby("ky")
)

# Now that the spectrum owns the "kx"/"ky" names, relabel the spatial patch
# axes as "x"/"y" and give them coordinates centred on zero.
# NOTE(review): linspace with block_size points over a span of hres*block_size
# yields a spacing of hres*block_size/(block_size-1), not hres -- confirm
# whether these coordinates are meant to be exact sample positions.
x = np.linspace(-hres * block_size / 2, hres * block_size / 2, block_size)
patches = patches.rename({"kx": "x", "ky": "y"}).assign_coords(x=x, y=x)

KDIMS = ["kx", "ky"]

# Spectral power of every patch.
power = abs(F) ** 2

# Keep only the spectral bins whose power exceeds the 95th percentile taken
# over (kx, ky), and normalise the survivors so that they sum to one. Bins at or
# below the threshold become NaN.
threshold = power.quantile(.95, KDIMS).drop_vars("quantile")
strong_power = power.where(power > threshold)
W = strong_power / strong_power.sum(KDIMS)

# Phase of every spectral bin, unwrapped along time and expressed in cycles.
wrapped_phase = xarray.ufuncs.angle(F)
phase = xarray.apply_ufunc(
    np.unwrap,
    wrapped_phase,
    input_core_dims=[["time"]],
    output_core_dims=[["time"]],
    dask="parallelized",
) / (2*np.pi)

# Rate of change of the phase, in cycles per second.
freq = phase.differentiate("time", datetime_unit="s")

# Centred 9-sample rolling statistics along time; windows that are truncated at
# the ends of the series are still evaluated (min_periods=1).
rolling_kwargs = dict(time=9, center=True, min_periods=1)
freq_noise_power = freq.rolling(**rolling_kwargs).var()
freq_smooth = freq.rolling(**rolling_kwargs).mean()

# Squared rolling mean over rolling variance.
freq_snr = freq_smooth**2 / freq_noise_power

# Wavevector of every spectral bin, encoded as the complex number kx + i*ky.
k = W.kx + W.ky * 1j
k2 = k ** 2
# W-weighted sums over the spectral plane: S0 of |k|^2 (real), S2 of k^2 (complex).
S0 = (W * k * k.conj()).sum(KDIMS).real
S2 = (W * k2).sum(KDIMS)
# |S2| / S0: ratio of the magnitude of the complex second moment to the weighted
# sum of |k|^2.
anisotropy_mag = abs(S2) / S0
# Unit complex number whose argument is half the argument of S2. Squaring k
# doubles its angle, so halving maps back to an orientation; k and -k give the
# same S2, hence the sign of `direction` is arbitrary.
direction = np.exp(1j * xarray.ufuncs.angle(S2) / 2)

# positive / negative projection from phase velocity
# Re(direction * conj(k)) is the dot product of k with `direction`; m is its
# sign (+1, -1, or 0 for bins orthogonal to `direction`).
m = np.sign((direction * k.conj()).real)
# W-weighted sums of the smoothed frequency and of the wavevector, with each bin
# multiplied by its sign m.
wmean_freq = (W * freq_smooth * m).sum(KDIMS)
wmean_wavevector = (W * k * m).sum(KDIMS)

# Spacing of the wavenumber grid.
h = wavenum[1] - wavenum[0]

# Central finite difference of the smoothed frequency with respect to kx (real
# part) and ky (imaginary part), evaluated at the grid points nearest to the
# weighted-mean wavevector shifted by +/- h. Presumably the group velocity; the
# factor 1000 is presumably a unit conversion -- confirm.
_kx_mean = wmean_wavevector.real
_ky_mean = wmean_wavevector.imag
vg = 1000 * (
    (
        freq_smooth.sel(kx=_kx_mean + h, ky=_ky_mean, method="nearest")
        - freq_smooth.sel(kx=_kx_mean - h, ky=_ky_mean, method="nearest")
    )
    + 1j * (
        freq_smooth.sel(kx=_kx_mean, ky=_ky_mean + h, method="nearest")
        - freq_smooth.sel(kx=_kx_mean, ky=_ky_mean - h, method="nearest")
    )
) / (2 * h)

# Phase velocity as a complex vector along the weighted-mean wavevector:
# f * k / |k|^2. Dividing by the complex k directly (f / k) equals
# f * conj(k) / |k|^2, which has the same magnitude but flips the sign of the
# ky component -- inconsistent with the kx + i*ky convention used for k and vg.
vp = 1000 * wmean_freq * wmean_wavevector / abs(wmean_wavevector) ** 2

In [246]:
# Inspect a single patch / time step.
II = dict(time=10, px=0, py=7)

freq_here = freq_smooth.isel(II)
sign_here = m.isel(II)
vg_here = vg.isel(II)

# Signed, weighted wavevector of every spectral bin; its sum is printed.
V = W.isel(II) * k * sign_here
print(V.sum())

# Left panel: derivative of the smoothed frequency along ky.
dfreq_dky_panel = hv.Image(freq_here.differentiate("ky")).opts(cmap="coolwarm")

# Right panel: signed frequency, the per-bin vectors, and a short segment that
# starts at vg's (kx, ky) coordinates and points along vg (scaled by 1/100).
signed_freq_image = hv.Image(freq_here * sign_here)
bin_vectors = hv.VectorField(
    (V.kx, V.ky, xarray.ufuncs.angle(V).T, abs(V).T),
    kdims=["kx", "ky"],
    vdims=["ang", "mag"],
).opts(magnitude="mag")
vg_segment = hv.Curve((
    [vg_here.kx, vg_here.kx+vg_here.real/100],
    [vg_here.ky, vg_here.ky+vg_here.imag/100],
))

dfreq_dky_panel + signed_freq_image * bin_vectors * vg_segment

<xarray.DataArray ()> Size: 16B
array(0.00289745-0.00151908j)
Coordinates:
    dir      float64 8B 2.628
    lam      float64 8B 315.8
    snr      float64 8B 0.9508
    tau      float64 8B 24.64
    time     datetime64[ns] 8B 2025-01-01T00:10:00
    trial    int64 8B 50
    px       int64 8B -1180
    py       int64 8B -60


In [132]:
KDIMS = ["kx", "ky"]
# Inspect a single patch / time step.
II = dict(time=10, px=5, py=5)

k = W.kx + W.ky * 1j
k2 = k ** 2
k_mag = abs(k)

# Per-bin contribution W * f * k / |k|^2 (not summed here).
pv = W.isel(II) * k * freq_smooth.isel(II) / abs(k)**2

# Spectral power; the label reports the patch's power threshold.
threshold_label = f"threshold {threshold.isel(**II).item():.3f}"
power_image = hv.Image(power.isel(**II)).opts(cmap=cc.cm.gouldian).relabel(threshold_label)

# Contour of the power image at the fixed level 100 (not at the threshold
# reported in the label).
contours = hv.operation.contours(
    power_image,
    levels=[100],
    overlaid=False,
).opts(show_legend=False, cmap=cc.cm.glasbey_bw)

# Weights divided by |k|; the k = 0 bin is masked to avoid dividing by zero.
weight_over_k_image = hv.Image(W.isel(II)/k_mag.where(k_mag>0)).opts(cmap=cc.cm.gouldian)
freq_image = hv.Image(freq_smooth.isel(II), vdims="freq")

power_image * contours + weight_over_k_image * contours + freq_image

In [None]:
# Inspect a single patch / time step.
II = dict(time=10, px=5, py=5)
k = W.kx + W.ky * 1j

# Weighted sum of the per-bin terms f * k / |k|^2, and the sign of every bin's
# projection onto it.
pv = (W.isel(II) * k * freq_smooth.isel(II) / abs(k)**2).sum(["kx", "ky"])
m = np.sign((k * pv.conj()).real)

# Reference period carried as the scalar `tau` coordinate. The factor 60 below
# presumably converts minutes to seconds -- confirm.
true_tau = freq_smooth.tau.item()

# Two weighted frequency estimates: pf1 uses |f|, pf2 uses f with the sign m.
pf1 = (abs(freq_smooth).isel(**II) * W.isel(**II)).sum(["kx", "ky"]).item()
pf2 = (freq_smooth.isel(**II) * W.isel(**II) * m).sum(["kx", "ky"]).item()
tf = 1/(60*true_tau)
# Relative error of the period implied by each estimate.
er1 = (1 / (60*pf1) - true_tau) / true_tau
er2 = (1 / (60*pf2) - true_tau) / true_tau
print(tf, pf1, er1, pf2, er2)

power_image = (
    hv.Image(power.isel(**II))
    .opts(cmap=cc.cm.gouldian)
    .relabel(f"threshold {threshold.isel(**II).item():.3f}")
)

# Outline of the region whose power exceeds the patch's threshold.
contours = hv.operation.contours(
    power_image,
    levels=[threshold.isel(**II)],
    overlaid=False,
).opts(show_legend=False)

# Smoothed frequency relative to the reference frequency.
freq_img = (
    hv.Image((freq_smooth-tf).isel(**II))
    .opts(cmap=cc.cm.coolwarm, symmetric=True)
)
(
    power_image * contours +
    # The title previously referenced an undefined name `pf` (NameError on a
    # fresh kernel); the sign-corrected estimate pf2 is reported instead.
    (freq_img * contours).relabel(f"actual: {true_tau:.2f}, estimate: {1/(60*pf2):.2f}")
).opts(shared_axes=False)

0.0006762708765249261 0.0006860379289816468 -0.014236898637979105 0.0006752985390116872 0.001439863196893406


In [97]:
# Wavevector of every spectral bin as kx + i*ky.
k = W.kx + W.ky * 1j

# Weighted sum of the per-bin terms f * k / |k|^2 for the patch selected by II,
# normalised to unit magnitude.
per_bin_velocity = W.isel(II) * k * freq_smooth.isel(II) / abs(k)**2
pv = per_bin_velocity.sum(["kx", "ky"])
pvd = pv / abs(pv)

# Sign of each bin's projection onto that direction.
projection = (k * pvd.conj()).real
m = np.sign(projection)

hv.Image(m).opts(cmap=cc.cm.coolwarm)

In [103]:
# Weights and smoothed frequency of the patch selected by II.
f = W.isel(II)
freq_here = freq_smooth.isel(II)

# Signed, weighted sum of the smoothed frequency.
print((f * freq_here * m).sum(["kx", "ky"]))

# Difference from the reference frequency tf: unsigned |f| on the left,
# sign-corrected m*f on the right.
diverging = dict(cmap=cc.cm.coolwarm, symmetric=True)
unsigned_residual = hv.Image(abs(freq_here) - tf).opts(**diverging)
signed_residual = hv.Image(m * freq_here - tf).opts(**diverging)

unsigned_residual + signed_residual

<xarray.DataArray ()> Size: 8B
array(0.0006753)
Coordinates:
    dir      float64 8B 2.628
    lam      float64 8B 315.8
    snr      float64 8B 0.9508
    tau      float64 8B 24.64
    time     datetime64[ns] 8B 2025-01-01T00:10:00
    trial    int64 8B 50
    px       int64 8B -380
    py       int64 8B -380
