In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
from dask.diagnostics import ProgressBar

import xarray
import pandas as pd
import holoviews as hv
import colorcet as cc
from holoviews.plotting.util import colorcet_cmap_to_palette
from src.estimator_analysis import get_results

hv.opts.defaults(
    hv.opts.Scatter(frame_height=400, frame_width=400, tools=["hover"], show_grid=True),
    hv.opts.Image(frame_height=400, frame_width=400, tools=["hover"], show_grid=True, colorbar=True),
    hv.opts.Path(frame_height=300, frame_width=400, tools=["hover"], show_grid=True, line_width=2),
    hv.opts.VectorField(frame_height=500, frame_width=500, show_grid=True, tools=[]),
)

In [None]:
DATA_FILE = "/disk1/tid/users/starr/results/data5.zarr"
RESULTS_FILE = "/disk1/tid/users/starr/results/results5_patch.zarr"
data = xarray.open_zarr(DATA_FILE, chunks=dict(px=-1, py=-1, time=-1, trial=1))
results = get_results(RESULTS_FILE, center=(100, 100))

  return self.func(*new_argspec)


In [3]:
m = (results.snr > 2) & (results.lam > 200) & (results.tau < 30)
filtered_results = results.isel(trial=m.compute())
hv.Scatter((filtered_results.trial.values, filtered_results.snr.values, filtered_results.lam.values, filtered_results.tau.values), kdims="trial", vdims=["snr", "lam", "tau"]).opts(color="lam", size=10)

In [4]:
trial = 366
T = 10
D = data.sel(trial=trial).compute()
R = filtered_results.sel(trial=trial).compute()

In [5]:
1000 * R.lam / (R.tau * 60)

In [6]:
from scipy.signal.windows import kaiser
from scipy.fft import fft2, fftfreq

block_size = 32
step_size = 8
hres = 20
Nfft = 128
edges = block_size // (2 * step_size)
window = kaiser(block_size, 5)
window = np.outer(window, window) / np.sum(window)

patches = (
    D.image
    .rolling(y=block_size, x=block_size, center=True)
    .construct(x="kx", y="ky", stride=step_size)
    .pipe(lambda x: x * window)
    .isel(x=slice(edges, -edges), y=slice(edges, -edges))
    .rename({"x": "px", "y": "py"})
)
patches = (patches - patches.mean(["kx", "ky"])) / patches.std(["kx", "ky"])
wavenum = fftfreq(Nfft, hres)

F = (
    xarray.apply_ufunc(
        lambda x: fft2(x * window, s=(Nfft, Nfft)),
        patches,
        input_core_dims=[["kx", "ky"]],
        output_core_dims=[["kx", "ky"]],
        output_dtypes=[np.complex128],
        dask_gufunc_kwargs={"output_sizes": {"kx": Nfft, "ky": Nfft}},
        dask="parallelized",
        exclude_dims={"kx", "ky"}
    )
    .assign_coords(kx=wavenum, ky=wavenum)
    .sortby("kx").sortby("ky")
)

x = np.linspace(-hres * block_size / 2, hres * block_size / 2, block_size)
patches = patches.rename({"kx": "x", "ky": "y"}).assign_coords(x=x, y=x)

power = abs(F) ** 2

phase = xarray.apply_ufunc(
    np.unwrap,
    xarray.ufuncs.angle(F),
    input_core_dims=[["time"]],
    output_core_dims=[["time"]],
    dask="parallelized",
) / (2*np.pi)

freq = phase.differentiate("time", datetime_unit="s")
freq_noise_power = freq.rolling(time=9, center=True, min_periods=1).var()
freq_smooth = freq.rolling(time=9, center=True, min_periods=1).mean()

In [None]:
threshold = power.quantile(.95, ["kx", "ky"]).drop_vars("quantile")
W = power.where(power > threshold)
W /= W.sum(["kx", "ky"])
k = W.kx + W.ky * 1j
pv = (W * k * freq / abs(k)**2)
patch_pv = pv.sum(["kx", "ky"])

PDIM = dict(px=3, py=-1)

img = D.isel(time=T).image
patch = patches.isel(time=T, **PDIM)
pwr = power.isel(time=T, **PDIM)
pwr_T = threshold.isel(time=T, **PDIM)
w = W.isel(time=T, **PDIM)
f = freq_smooth.isel(time=T, **PDIM)
(
    (
        hv.Image(img).opts(cmap=cc.cm.diverging_bwr_55_98_c37, symmetric=True) * 
        hv.Box(patch.px.values, patch.py.values, 640).opts(line_width=2, line_dash="dashed")
    ) + 
    hv.Image(patch).opts(cmap=cc.cm.diverging_bwr_55_98_c37, symmetric=True) +
    hv.operation.contours(hv.Image(pwr).opts(cmap=cc.cm.linear_protanopic_deuteranopic_kbjyw_5_95_c25), levels=[pwr_T], overlaid=True).opts(show_legend=False) +
    hv.Image(w).opts(cmap=cc.cm.gouldian) +
    hv.Image(w*f).opts(cmap=cc.cm.bky, symmetric=True)
).opts(shared_axes=False).cols(2)

In [None]:
s = {**PDIM, **{k: v.item() for k,v in w.argmax(["kx", "ky"]).items()}}
(
    hv.Spread((np.arange(freq.time.shape[0]), freq_smooth.isel(s), xarray.ufuncs.sqrt(freq_noise_power.isel(s))), kdims="time", vdims=["freq", "freq_err"]) *
    hv.Scatter((np.arange(freq.time.shape[0]), freq.isel(s)), kdims="time", vdims="freq").opts(color="k") * 
    hv.Path((np.arange(freq.time.shape[0]), freq_smooth.isel(s)), kdims=["time", "freq"]).opts(line_color="r")
)

In [None]:
k = W.kx + W.ky * 1j
pv = (-1 * k * freq / abs(k)**2)
ppv = pv.isel(time=T, **PDIM) * 1000
ppv = ppv.where((abs(k)>5.0e-4) & (w > 1.0e-4))

lims = [abs(ppv).min().item(), abs(ppv).max().item()]
print(lims)
(
    hv.Image(
        (ppv.kx, ppv.ky, w.values),
        kdims=["kx", "ky"],
    ).opts(
        cmap=cc.cm.linear_worb_100_25_c53,
        colorbar=True
    ) *
    hv.VectorField(
        (ppv.kx, ppv.ky, xarray.ufuncs.angle(ppv).T, abs(ppv).T),
        kdims=["kx", "ky"],
        vdims=["arg", "mag"]
    ).opts(
        # magnitude="mag"
        magnitude=hv.dim("mag").lognorm(lims)/1500,
        rescale_lengths=False,
    ).redim.range(kx=(-.0075, .0075), ky=(-.0075, .0075))
)

[43.955227966983216, 1336.4885919153185]


In [None]:
# k = power1.kx + 1j * power1.ky
# k2 = k * k
# weight1 = power1.where(power1 > power1.quantile(.95)).drop_vars("quantile")
# weight1 /= weight1.sum()
# S0_1 = (weight1 * k * k.conj()).sum().real
# S2_1 = (weight1 * k2).sum()
# anisotropy_1 = abs(S2_1) / S0_1

# k = power2.kx + 1j * power2.ky
# k2 = k * k
# weight2 = power2.where(power2 > power2.quantile(.95)).drop_vars("quantile")
# weight2 /= weight2.sum()
# S0_2 = (weight2 * k * k.conj()).sum().real
# S2_2 = (weight2 * k2).sum()
# anisotropy_2 = abs(S2_2) / S0_2

# print(S0_1.item(), S2_1.item(), anisotropy_1.item())
# print(S0_2.item(), S2_2.item(), anisotropy_2.item())
# # hv.Image(weight2).opts(cmap=cc.cm.linear_protanopic_deuteranopic_kbjyw_5_95_c25, colorbar=True)