In [None]:
# --- Run parameters ---------------------------------------------------------
# Seed handed to jax.random.PRNGKey when the master key is created further down.
seed = 0

# Input recording; vowel, modality and pitch are parsed from its file name later on.
# NOTE(review): absolute, machine-specific path — override it when running elsewhere.
wav_file = "/home/marnix/thesis/data/OPENGLOT/RepositoryI/Vowel_O/O_normal_320Hz.wav"

# Forwarded to periodic_kernel_phi as `noise_floor_db`.
noise_floor_db = -40.0

# JAX settings, applied through jax.config.update in the import cell.
jax_platform_name = "gpu"
jax_enable_x64 = True

# Sample count stored in Hyperparams, and batch size given to vi_frames_batched.
num_metrics_samples = 1
batch_size = 1

In [2]:
from time import perf_counter, time

import jax
import jax.numpy as jnp


# First to set config flags wins!
# Apply the notebook-level JAX settings (x64 mode and platform name) right after
# importing jax and before the project modules below are imported.
# NOTE(review): presumably importing those modules first would lock in different
# settings — confirm before moving these two calls.
jax.config.update("jax_enable_x64", jax_enable_x64)
jax.config.update("jax_platform_name", jax_platform_name)

import numpy as np
import soundfile as sf

from iklp.run import vi_frames_batched
from iklp.hyperparams import Hyperparams
from iklp.periodic import periodic_kernel_phi
from utils.audio import frame_signal, resample

# Report the JAX configuration that is actually in effect
_probe = jnp.array(0.0)  # throwaway scalar, inspected only for its dtype and device
print("default float:", _probe.dtype)
print("backend:", _probe.device)



default float: float64
backend: cuda:0


In [3]:
# Decode "<vowel>_<modality>_<pitch>Hz" from the file name, e.g. "O_normal_320Hz.wav"
_file_name = wav_file.split("/")[-1]
_stem = _file_name.split(".")[-2]  # the part just before the final extension
vowel, modality, true_pitch = _stem.split("_")[:3]

vowel, modality = vowel.lower(), modality.lower()
# Drop the trailing two characters (the 'Hz' unit) and keep the pitch as an integer
true_pitch = int(true_pitch.lower()[:-2])

In [4]:
# Model settings, chosen to match OPENGLOT and Yoshii
P = 9
I = 400

# Pitch range; the lower bound of 70 is a typical value (Nielsen 2013)
f0_min, f0_max = 70, 400

# Nudge: initial pitchedness, turned into the `aw` hyperparameter further down
initial_pitchedness = 0.99

# Sampling rate and framing — adjust these three together
target_sr, frame_len, hop = 8000, 1024, 80

# batch_size is not set here; it comes from the parameter cell at the top.

In [5]:
# Read the recording in the float width that matches the JAX x64 setting
audio, sr_in = sf.read(
    wav_file,
    always_2d=False,
    dtype="float64" if jax_enable_x64 else "float32",
)

# Column 0 is used as the signal `x`, column 1 as `dgf`
x, dgf = audio[:, 0], audio[:, 1]

# Bring both channels from the file's rate to the target rate
x, dgf = (resample(channel, sr_in, target_sr) for channel in (x, dgf))

# Scale so that x has unit power; dgf is multiplied by the same factor
scale = 1 / np.sqrt(np.mean(x**2))
x, dgf = x * scale, dgf * scale

# Cut the signal into frames and hand them to JAX
frames = jnp.array(frame_signal(x, frame_len, hop))  # (n_frames, frame_len)

print(
    f"→ Loaded {wav_file} ({len(x)} samples, {sr_in} Hz), {frames.shape[0]} frames of {frame_len} samples each"
)




→ Loaded /home/marnix/thesis/data/OPENGLOT/RepositoryI/Vowel_O/O_normal_320Hz.wav (1600 samples, 8000 Hz), 8 frames of 1024 samples each


In [None]:
from utils.jax import maybe32


# Random key for the VI run, derived from the notebook seed
master_key = jax.random.PRNGKey(seed)

# Evaluate periodic_kernel_phi for the configured pitch range and frame length
f0, Phi = periodic_kernel_phi(
    I=I, M=frame_len, fs=target_sr,
    f0_min=f0_min, f0_max=f0_max,
    noise_floor_db=noise_floor_db,
)

# Odds form p / (1 - p) of the initial pitchedness
aw = initial_pitchedness / (1 - initial_pitchedness)

# Bundle everything the VI routine needs
h = Hyperparams(
    Phi,
    P=P,
    aw=maybe32(aw),
    num_vi_restarts=5,
    num_vi_iters=30,
    num_metrics_samples=num_metrics_samples,
)

In [12]:
# Time one complete VI run over all frames.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; the wall clock returned by time() can jump if the system clock
# is adjusted during the run.
# NOTE(review): the measurement likely includes JIT compilation on a first call.
t0 = perf_counter()
metrics = vi_frames_batched(master_key, frames, h, batch_size=batch_size)
# Materialize: JAX dispatches work asynchronously, so wait for the results
# before stopping the clock.
jax.block_until_ready(metrics)
walltime = perf_counter() - t0

print(f"Metrics shape: {metrics.elbo.shape}")
print(f"Walltime for VI: {walltime:.2f} seconds")

# Product of the first three dimensions of the ELBO array
total_iters = np.prod(metrics.elbo.shape[:3])

print(f"Total iterations: {total_iters}")
print(f"Time per iteration: {walltime / total_iters:.2f} seconds")

Metrics shape: (2, 2, 31)
Walltime for VI: 142.87 seconds
Total iterations: 124
Time per iteration: 1.15 seconds


In [8]:
# Dimensions of the Phi array held by the hyperparameters; note this rebinds `I`
I, M, r = h.Phi.shape
# Average wall-clock seconds spent per counted VI iteration
time_per_iter = walltime / total_iters

In [9]:
import scrapbook as sb
import numpy as np


def _to_py(x):
    if hasattr(x, "tolist"):
        try:
            return x.tolist()
        except Exception:
            pass
    if isinstance(x, np.generic):
        return x.item()
    return x


def _walk(x):
    """Recursively convert a value with ``_to_py``.

    Dict values and list/tuple items are converted in turn; dict keys are kept
    as they are, and tuples come back as lists.
    """
    value = _to_py(x)
    if isinstance(value, dict):
        converted = {}
        for key, item in value.items():
            converted[key] = _walk(item)
        return converted
    if isinstance(value, (list, tuple)):
        return list(map(_walk, value))
    return value


# glue exports: record run parameters and results in the notebook via scrapbook
_exports = (
    ("I", I),
    ("M", M),
    ("batch_size", batch_size),
    ("jax_enable_x64", jax_enable_x64),
    ("jax_platform_name", jax_platform_name),
    ("modality", modality),
    ("noise_floor_db", noise_floor_db),
    ("r", r),
    ("seed", seed),
    ("time_per_iter", time_per_iter),
    ("true_pitch", true_pitch),
    ("vowel", vowel),
    ("wav_file", wav_file),
)
for _name, _value in _exports:
    # Each value is converted to plain Python types before being glued
    sb.glue(_name, _walk(_value))