In [8]:
import os
from pprint import pprint
from time import time

import jax
import jax.numpy as jnp
import numpy as np
import soundfile as sf

from iklp.hyperparams import Hyperparams
from iklp.periodic import periodic_kernel_phi
from iklp.state import compute_expectations, init_state
from iklp.vi import compute_elbo_bound, vi_step
from utils.audio import frame_signal, resample
from utils.stats import (
    average_list_of_dicts,
    kl_div,
    normalize_weights,
    weighted_pitch_error,
)

In [9]:
seed = 0
wav_file = (
    os.environ["PROJECT_DATA_PATH"]
    + "/OPENGLOT/RepositoryI/Vowel_AE/AE_whispery_260Hz.wav"
)

In [10]:
# From OPENGLOT paper Table 1
true_resonance_frequencies = {
    "a": [730, 1090, 2440, 3500],
    "e": [530, 1840, 2480, 3500],
    "i": [390, 1990, 2550, 3500],
    "o": [570, 840, 2410, 3500],
    "u": [440, 1020, 2240, 3500],
    "ae": [660, 1720, 2410, 3500],
}

In [11]:
vowel, modality, true_pitch = (
    wav_file.split("/")[-1].split(".")[-2].split("_")[:3]
)
vowel = vowel.lower()
modality = modality.lower()
true_pitch = int(true_pitch.lower()[:-2])  # Remove 'Hz' from the pitch string

f1, f2, f3, f4 = true_resonance_frequencies[vowel]

In [12]:
# Use same parameters as in OPENGLOT and Yoshii
P = 9
I = 400
f0_min = 70  # Typical lower bound (Nielsen 2013)
f0_max = 400

# Adjust these together
target_sr = 8000
frame_len = 1024
hop = 80

noise_floor_db = -60.0
max_iter = 50
criterion = 0.0001

verbose = True

In [13]:
audio, sr_in = sf.read(wav_file, always_2d=False, dtype="float64")

# Split channels
x = audio[:, 0]
dgf = audio[:, 1]

x = resample(x, sr_in, target_sr)
dgf = resample(dgf, sr_in, target_sr)

# Normalize data to unit power
scale = 1 / np.sqrt(np.mean(x**2))
x = x * scale
dgf = dgf * scale

print(f"→ Loaded {wav_file} ({len(x)} samples, {sr_in} Hz)")

→ Loaded /home/marnix/thesis/data/OPENGLOT/RepositoryI/Vowel_AE/AE_whispery_260Hz.wav (1600 samples, 8000 Hz)


In [14]:
f0, Phi = periodic_kernel_phi(
    I=I,
    M=frame_len,
    fs=target_sr,
    f0_min=f0_min,
    f0_max=f0_max,
    noise_floor_db=noise_floor_db,
)

h = Hyperparams(Phi, P=P)

vi_step = jax.jit(vi_step)
compute_elbo_bound = jax.jit(compute_elbo_bound)
compute_expectations = jax.jit(compute_expectations)

In [15]:
true_posterior = np.zeros_like(f0)
q = np.abs(f0 - true_pitch).argmin().item()
true_posterior[q] = 1.0

uniform_prior = np.ones_like(f0) / len(f0)


def compute_frame_stats(frame, state):
    E = compute_expectations(state)

    # Characterize the posterior of theta (fundamental frequency)
    theta_posterior = normalize_weights(E.theta)

    compression = np.exp(kl_div(theta_posterior, uniform_prior))
    fitness = np.exp(kl_div(true_posterior, theta_posterior))

    # Fundamental frequency error
    pitch_wrmse, pitch_wmae = weighted_pitch_error(f0, E.theta, true_pitch)

    # Power and pitchedness
    frame_power = np.mean(frame**2)
    nu_e = E.nu_e
    nu_w = E.nu_w
    pitchedness = nu_w / (nu_e + nu_w)

    # AR
    a = state.xi.delta_a

    stats = {
        "compression": compression.item(),
        "fitness": fitness.item(),
        "pitch_wrmse": pitch_wrmse.item(),
        "pitch_wmae": pitch_wmae.item(),
        "frame_power": frame_power.item(),
        "nu_e": E.nu_e.item(),
        "nu_w": E.nu_w.item(),
        "pitchedness": pitchedness.item(),
        "a": a.tolist(),
    }

    return stats

In [None]:
def process_frame(key, frame_index, frame, h):
    state = init_state(key, frame, h)
    score = -jnp.inf

    started = time()
    for i in range(max_iter):
        state = vi_step(state)
        lastscore = score
        score = compute_elbo_bound(state)

        if i == 0:
            improvement = 1.0
        else:
            improvement = (score - lastscore) / jnp.abs(lastscore)

        if verbose:
            print(
                f"Frame {frame_index}, iteration {i}: "
                f"bound = {score:.2f} "
                f"({improvement:+.5f} improvement)"
            )

        if improvement < 0.0:
            break
        if improvement < criterion:
            break
        if jnp.isnan(improvement) and i > 0:
            break

    walltime = time() - started

    stats = compute_frame_stats(frame, state)

    stats["frame_index"] = frame_index
    stats["num_iters"] = i + 1
    stats["time_per_iter"] = walltime / (i + 1)
    stats["score"] = score.item()

    if verbose:
        pprint(stats)

    return stats

In [17]:
frames = frame_signal(x, frame_len, hop)

master_key = jax.random.PRNGKey(seed)
keys = jax.random.split(master_key, len(frames))

In [18]:
frame_stats = [
    process_frame(key, idx, frame, h)
    for idx, (key, frame) in enumerate(zip(keys, frames))
]

Frame 0, iteration 0: bound = -2102.25 (+1.00000 improvement)
Frame 0, iteration 1: bound = -1401.21 (+0.33347 improvement)
{'a': [1.3387344266660306,
       -1.176715610272982,
       0.8223889009854591,
       -0.36688087250831325,
       -0.138149767985529,
       0.048964262169871424,
       -0.1371183785097248,
       0.015289518621954384,
       0.056303931195212105],
 'compression': 1.4046117057115843,
 'fitness': 35.60015358262199,
 'frame_index': 0,
 'frame_power': 1.0020824189947481,
 'nu_e': 0.28164945384151246,
 'nu_w': 0.428844674295871,
 'num_iters': 2,
 'pitch_wmae': 113.13535718784249,
 'pitch_wrmse': 92.76397132018236,
 'pitchedness': 0.6035865143884035,
 'score': -1401.2102037214718,
 'time_per_iter': 9.164981603622437}
Frame 1, iteration 0: bound = -2077.36 (+1.00000 improvement)
Frame 1, iteration 1: bound = -1365.65 (+0.34260 improvement)
{'a': [1.3176107577432608,
       -1.1995811455574148,
       0.8655946750068139,
       -0.423915117209073,
       -0.071677286

In [19]:
file_stats = average_list_of_dicts(frame_stats)

if verbose:
    print("File-level statistics:")
    pprint(file_stats)

File-level statistics:
{'a': 0.048837997570712166,
 'compression': 1.4146585233168623,
 'fitness': 37.874198495650376,
 'frame_index': 3.5,
 'frame_power': 0.9969126295327071,
 'nu_e': 0.2575138458919407,
 'nu_w': 0.5012687451415266,
 'num_iters': 2.0,
 'pitch_wmae': 116.21796674883592,
 'pitch_wrmse': 96.262738770451,
 'pitchedness': 0.6584877782067893,
 'score': -1382.5378845769117,
 'time_per_iter': 3.1737114936113358}
