In [None]:
# parameters, export
# NOTE(review): the tag line above presumably marks this as the injectable
# parameters cell and as an exported cell for the surrounding tooling -- confirm.
seed = 0  # seeds the root `jax.random.PRNGKey` (`master_key`) used by the VI run

jax_enable_x64 = False  # JAX 64-bit mode switch; also selects `batch_size` (128 if True, else 256)
r = 10  # rank passed to `psd_svd_fixed` when factorizing the kernel stack
beta = 0.0  # forwarded as `beta` to `pi_kappa_hyperparameters`
alpha_scale = 1.0  # multiplier applied to `solve_for_alpha(I)`
prior_pi = 0.95  # forwarded as `pi` to `pi_kappa_hyperparameters`
ell = 0.5  # lengthscale of the periodic kernel built for each candidate f0

In [None]:
import jax

# Global JAX configuration. These updates run before any JAX arrays are
# created in this notebook, which matters for `jax_enable_x64`.
jax.config.update("jax_log_compiles", False)  # don't log each compilation
jax.config.update("jax_enable_x64", jax_enable_x64)  # driven by the notebook parameter

# Root PRNG key derived from the `seed` parameter; passed to the VI run below.
master_key = jax.random.PRNGKey(seed)

In [None]:
import jax.numpy as jnp

from iklp.hyperparams import (
    ARPrior,
    pi_kappa_hyperparameters,
    solve_for_alpha,
)
from iklp.mercer import psd_svd_fixed
from utils.jax import maybe32
from utils.openglot import OpenGlotI
from utils.audio import frame_signal
from iklp.mercer_op import backend
from utils import time_this
from iklp.run import vi_run_criterion_batched

from tqdm import tqdm
import numpy as np


In [None]:
# Model size
P = 9  # passed to ARPrior.yoshii_lambda below (presumably the AR order)

# Candidate fundamental frequencies in Hz: 100, 110, ..., 360 (360 included)
f0 = np.arange(100, 361, 10)
I = f0.shape[0]  # number of candidate pitches, i.e. number of kernels

# Framing at the resampled rate
target_fs = 8000
M = 1024  # frame length
hop = 80

# Variational inference settings
max_vi_iter = 50
if jax_enable_x64:
    batch_size = 128  # smaller batches when running in 64-bit mode
else:
    batch_size = 256

In [None]:
# Build one periodic (exp-sine-squared) kernel per candidate pitch period and
# compress each to rank `r`.
dt = 1.0 / target_fs
t = dt * np.arange(M)  # sample times of one frame, in seconds
tau = np.subtract.outer(t, t)  # (M, M) lag matrix, tau[i, j] = t[i] - t[j]
T = 1 / f0[:, None, None]  # (I, 1, 1) pitch periods, broadcast over the lags
K = np.exp(-2 * np.sin(np.pi * tau / T) ** 2 / ell**2)  # (I, M, M)
Phi, energy = psd_svd_fixed(K, rank=r)

print(f"Energy captured at rank={r}:", energy)

In [None]:
# Prior over the order-`P` AR model (presumably the LPC/vocal-tract predictor;
# TODO confirm the "Yoshii lambda" parameterization in iklp.hyperparams).
arprior = ARPrior.yoshii_lambda(P)


In [None]:
# Hyperparameters: `alpha` is the solver's value for `I` kernels rescaled by
# the `alpha_scale` parameter; `kappa` is fixed here rather than exposed as a
# notebook parameter.
alpha = alpha_scale * solve_for_alpha(I)
kappa = 1.0

h = pi_kappa_hyperparameters(
    maybe32(Phi),
    pi=maybe32(prior_pi),
    kappa=maybe32(kappa),
    alpha=maybe32(alpha),
    arprior=arprior,
    num_metrics_samples=1,
    num_vi_iters=max_vi_iter,
    beta=maybe32(beta),
)

# Release the dense kernel stack and the module-level factor; the factor is
# still available as `h.Phi`. NOTE: after this, re-running this cell alone
# fails with NameError -- re-run the kernel cell first.
del K, Phi

print("Phi shape:", h.Phi.shape)  # (I, M, r)
print("Phi dtype:", h.Phi.dtype)
print("Mercer operator backend:", backend(h))

In [None]:
def all_runs(verbose=False):
    """Yield one run descriptor per (wav file, frame, VI restart).

    Every OpenGlot I wav file is read at `target_fs`, framed with length `M`
    and hop `hop` (both the speech signal and its glottal-flow reference),
    and cast with `maybe32`. For each frame, `h.num_vi_restarts` descriptors
    are emitted that differ only in `restart_index`.

    Args:
        verbose: forwarded to `OpenGlotI.read_wav`.

    Yields:
        dict holding file metadata, ground-truth pitch and formants, the
        current frame `x` and its reference `gf`, and the full frame stacks.
    """
    for wav_path in tqdm(OpenGlotI.wav_files()):
        vowel, modality, true_pitch = OpenGlotI.parse_wav(wav_path)
        true_formants = OpenGlotI.true_resonance_frequencies[vowel]
        x_full, gf_full, original_fs = OpenGlotI.read_wav(
            wav_path, target_fs, verbose=verbose
        )

        # Frame stacks are (n_frames, frame_len)
        x_frames = maybe32(frame_signal(x_full, M, hop))
        gf_frames = maybe32(frame_signal(gf_full, M, hop))
        num_frames = x_frames.shape[0]

        for frame_index, frame_pair in enumerate(zip(x_frames, gf_frames)):
            x, gf = frame_pair
            for restart_index in range(h.num_vi_restarts):
                yield dict(
                    wav_file=wav_path,
                    original_fs=original_fs,
                    target_fs=target_fs,
                    vowel=vowel,
                    modality=modality,
                    true_pitch=true_pitch,
                    true_formants=true_formants,
                    frame_index=frame_index,
                    num_frames=num_frames,
                    restart_index=restart_index,
                    x=x,
                    gf=gf,
                    x_frames=x_frames,
                    gf_frames=gf_frames,
                )


runs = [*all_runs()]

print("Total runs:", len(runs))


In [None]:
# Stack every run's frame into a single (num_runs, M) batch for the VI solver.
x = jnp.vstack(tuple(run["x"] for run in runs))

print("Data shape:", x.shape)
print("Data dtype:", x.dtype)

In [None]:
# Run batched VI over every row of `x` and time it.
# JAX dispatches device work asynchronously, so block on the outputs *inside*
# the timed region; otherwise `elapsed` may stop before the computation has
# actually finished and `time_per_iter` would be underestimated.
with time_this() as elapsed:
    metrics_tree, unpack = vi_run_criterion_batched(
        master_key, x, h, batch_size=batch_size, verbose=True
    )
    jax.block_until_ready(metrics_tree)

# Split the batched metrics back into one entry per run (same order as `runs`).
metrics_list = list(unpack(metrics_tree))

In [None]:
# export
# Exported summary scalars are cast to built-in `float`: NumPy float32 scalars
# and JAX 0-d arrays are not JSON-serializable, unlike a plain float.

# Mean energy captured by the chosen rank `r`
mean_energy = float(np.mean(energy))

# Wall-clock seconds per VI iteration (`metrics_tree.i` is presumably the
# per-run iteration count). This includes compilation for the shapes of the
# first and last batch, which costs on the order of a minute.
time_per_iter = float(elapsed.walltime / metrics_tree.i.sum())

# `zip` would silently drop trailing runs if these lengths ever disagreed.
assert len(runs) == len(metrics_list), (len(runs), len(metrics_list))
results = [
    OpenGlotI.post_process_run(run, metrics, f0)
    for run, metrics in tqdm(zip(runs, metrics_list), total=len(runs))
]

In [None]:
# Plot best u(t) fit: the run with the lowest aligned glottal-flow NRMSE.
# NaN scores are skipped; np.nanargmin raises ValueError if every score is NaN.
i = int(
    np.nanargmin([result["gf_aligned_nrmse"] for result in results])
)
print(i)

OpenGlotI.plot_run(runs[i], metrics_list[i], f0, retain_plots=True)

In [None]:
# Plot best pitch fit: the run with the lowest pitch weighted RMSE.
# NaN scores are skipped; np.nanargmin raises ValueError if every score is NaN.
i = int(
    np.nanargmin([result["pitch_wrmse"] for result in results])
)
print(i)

OpenGlotI.plot_run(runs[i], metrics_list[i], f0, retain_plots=True)

In [None]:
# Plot best formant fit: the run with the lowest formant RMSE.
# NaN scores are skipped; np.nanargmin raises ValueError if every score is NaN.
i = int(
    np.nanargmin([result["formant_rmse"] for result in results])
)
print(i)

OpenGlotI.plot_run(runs[i], metrics_list[i], f0, retain_plots=True)