In [1]:
%load_ext autoreload
%autoreload 3 -p

In [2]:
import numpy as np
from dask.diagnostics import ProgressBar

import xarray
import pandas as pd
import holoviews as hv
hv.extension("bokeh")
from holoviews.plotting.util import colorcet_cmap_to_palette
import colorcet as cc
from src.estimator_analysis import get_results

hv.opts.defaults(
    hv.opts.Scatter(frame_height=400, frame_width=400, tools=["hover"], show_grid=True),
    hv.opts.Image(frame_height=400, frame_width=400, tools=["hover"], show_grid=True, colorbar=True),
    hv.opts.Path(frame_height=300, frame_width=400, tools=["hover"], show_grid=True, line_width=2),
    hv.opts.VectorField(frame_height=400, frame_width=400, show_grid=True, tools=[]),
)

In [None]:
DATA_FILE = "/disk1/tid/users/starr/results/20260206_0639_data_planar.zarr"
RESULTS_FILE = "/disk1/tid/users/starr/results/20260206_0639_results_planar.zarr"

data = xarray.open_zarr(DATA_FILE)
results = get_results(RESULTS_FILE)
results["patch_freq_snr"] = np.log10(results["patch_freq_snr"])
results["rel_speed_error"] = results["speed_error"] / results["spd"]
results["rel_period_error"] = results["period_error"] / results["tau"]
results["rel_wavelength_error"] = results["wavelength_error"] / results["lam"]
results["e1"] = (results["S0"] + abs(results["S2"])) / 2
results["e2"] = (results["S0"] - abs(results["S2"])) / 2

with ProgressBar():
    results.compute()

r = results.median(["time"])

[########################################] | 100% Completed | 11.93 s


In [163]:
df = r.to_dataframe()
df2 = results.to_dataframe()
df2 = df2[~df2.isnull().any(axis=1)]

  return self.func(*new_argspec, **kwargs)
  return self.func(*new_argspec, **kwargs)


In [39]:
Y_vars = ["rel_period_error", "rel_wavelength_error"]
P = None
for Y in Y_vars:
    y = np.log1p(df2[Y])
    hist = np.histogram(y, bins=120, density=True)
    p = (
        hv.Histogram(hist, kdims=Y, vdims=Y+"density").opts(
            frame_width=350,
            frame_height=350,
            show_grid=True,
            show_legend=False,
            tools=["hover"],
        ).relabel(f"{np.mean(y):.2f}, {np.median(y):.2f}")
    )
    
    if P is None:
        P = p
    else:
        P += p
P.cols(3)

In [7]:
X = "rel_period_error"
Y = "rel_wavelength_error"

xlim = np.quantile(df2[X], [.001, .999])
ylim = np.quantile(df2[Y], [.001, .999])
m = (df2[X] >= xlim[0]) & (df2[X] <= xlim[1]) & (df2[Y] >= ylim[0]) & (df2[Y] <= ylim[1])
hv.HexTiles(df2[m], kdims=[X, Y]).opts(
    frame_width=500,
    frame_height=500,
    show_grid=True,
    show_legend=False,
    tools=["hover"],
    gridsize=100,
    min_count=100,
    colorbar=True,
)

In [None]:
from scipy.signal.windows import kaiser
from scipy.fft import fft2, fftfreq

trial = 50
D = data.sel(trial=trial).compute()

block_size = 32
step_size = 8
hres = 20
Nfft = 128
edges = block_size // (2 * step_size)
window = kaiser(block_size, 5)
window = np.outer(window, window) / np.sum(window)

patches = (
    D.image
    .rolling(y=block_size, x=block_size, center=True)
    .construct(x="kx", y="ky", stride=step_size)
    .pipe(lambda x: x * window)
    .isel(x=slice(edges, -edges), y=slice(edges, -edges))
    .rename({"x": "px", "y": "py"})
)
patches = (patches - patches.mean(["kx", "ky"])) / patches.std(["kx", "ky"])
wavenum = fftfreq(Nfft, hres)

F = (
    xarray.apply_ufunc(
        lambda x: fft2(x * window, s=(Nfft, Nfft)),
        patches,
        input_core_dims=[["kx", "ky"]],
        output_core_dims=[["kx", "ky"]],
        output_dtypes=[np.complex128],
        dask_gufunc_kwargs={"output_sizes": {"kx": Nfft, "ky": Nfft}},
        dask="parallelized",
        exclude_dims={"kx", "ky"}
    )
    .assign_coords(kx=wavenum, ky=wavenum)
    .sortby("kx").sortby("ky")
)

x = np.linspace(-hres * block_size / 2, hres * block_size / 2, block_size)
patches = patches.rename({"kx": "x", "ky": "y"}).assign_coords(x=x, y=x)

power = abs(F) ** 2

KDIMS = ["kx", "ky"]
threshold = power.quantile(.95, KDIMS).drop_vars("quantile")
W = power.where(power > threshold)
W = W / W.sum(KDIMS)

phase = xarray.apply_ufunc(
    np.unwrap,
    xarray.ufuncs.angle(F),
    input_core_dims=[["time"]],
    output_core_dims=[["time"]],
    dask="parallelized",
) / (2*np.pi)

freq = phase.differentiate("time", datetime_unit="s")
freq_noise_power = freq.rolling(time=9, center=True, min_periods=1).var()
freq_smooth = freq.rolling(time=9, center=True, min_periods=1).mean()
freq_snr = (freq_smooth**2 / freq_noise_power)

k = W.kx + W.ky * 1j
k2 = k ** 2
S0 = (W * abs(k)**2).sum(KDIMS)
S2 = (W * k2).sum(KDIMS)
anisotropy_mag = abs(S2) / S0
direction = np.exp(1j * xarray.ufuncs.angle(S2) / 2)

# positive / negative projection from phase velocity
m = np.sign((direction * k.conj()).real)
wmean_freq = (W * freq_smooth * m).sum(KDIMS)
wmean_wavevector = (W * k * m).sum(KDIMS)

h = wavenum[1] - wavenum[0]
vg = 1000 * (
    (
        freq_smooth.sel(kx=wmean_wavevector.real+h, ky=wmean_wavevector.imag, method="nearest") - 
        freq_smooth.sel(kx=wmean_wavevector.real-h, ky=wmean_wavevector.imag, method="nearest")
    ) + 
    1j * (
        freq_smooth.sel(kx=wmean_wavevector.real, ky=wmean_wavevector.imag+h, method="nearest") - 
        freq_smooth.sel(kx=wmean_wavevector.real, ky=wmean_wavevector.imag-h, method="nearest")
    )
) / (2 * h)

vp = 1000 * wmean_freq / wmean_wavevector

In [160]:
II = dict(px=10, py=10, time=30)
K = k2.isel(kx=slice(None, None, 2), ky=slice(None, None, 2))
R = (1000*(k.conj() * direction.isel(II) * np.exp(.5j * np.pi)).real)**2
print((R*W.isel(II)).sum(KDIMS))
(
    hv.Image(power.isel(II)).opts(cmap=cc.cm.linear_worb_100_25_c53)
    * hv.VectorField((K.kx, K.ky, xarray.ufuncs.angle(K).T, abs(K / abs(K)).T), kdims=["kx", "ky"], vdims=["a", "m"]).opts(magnitude="m")
    + hv.Image(R * W.isel(II))
)

<xarray.DataArray ()> Size: 8B
array(2.06812483)
Coordinates:
    dir      float64 8B 0.8579
    lam      float64 8B 159.7
    snr      float64 8B -0.1816
    tau      float64 8B 21.82
    trial    int64 8B 50
    time     datetime64[ns] 8B 2025-01-01T00:30:00
    px       int64 8B 420
    py       int64 8B 420


In [179]:
np.std(np.random.randn(100)+10)

np.float64(0.9398978731818133)

In [184]:
hv.Scatter(df.assign(x=1/np.sqrt(df["e2"])), kdims="snr", vdims=["x", "lam", "tau"]).opts(color="lam", colorbar=True)