# Estimating $p(T)$ from APLAWD

- APLAWD database [@Lindsey1987]. British English.

- APLAWD GCI markings [@Serwy2017]. Very high-quality, hand-corrected.

- Time shift of WAV vs. EGG is ~ 0.95 msec [@Naylor2007]. **OK: Implemented.**
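
The time shift in the last bullet can be sketched as follows. This is a minimal illustration, not the code in `lib.aplawd`: the helper `delay_signal` and the sampling rate are hypothetical, and it assumes the speech channel lags the EGG channel (larynx-to-microphone propagation), so the EGG is delayed to match.

```python
import numpy as np

def delay_signal(x, fs, delay_msec=0.95):
    """Delay `x` by `delay_msec`, zero-padding the start and keeping the length."""
    n = int(round(delay_msec * 1e-3 * fs))  # delay in whole samples
    if n == 0:
        return x.copy()
    return np.concatenate([np.zeros(n, dtype=x.dtype), x[:len(x) - n]])

fs = 20_000  # Hz; assumed sampling rate, for illustration only
egg = np.zeros(100)
egg[10] = 1.0  # a single impulse at sample 10

egg_delayed = delay_signal(egg, fs)
shift = int(np.argmax(egg_delayed)) - 10  # number of samples the impulse moved
print(shift)
```

At a sampling rate of 20 kHz, 0.95 msec corresponds to 19 samples.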

## Note about jitter

We model the 'true' pitch periods by a GP with a noise term $\sigma_n^2$ (i.e., `noise_sigma**2`) which, when fitted, predicts unrealistically high voice jitter. This is because the term has picked up all kinds of other effects in the data, working as a 'shock absorber' that keeps the other GP parameters unharmed. We acknowledge this and, during inference of the latent GP function, replace the $\sigma_n^2$ noise term by zero. The process still has inherent jitter, which we calculate below. The kernel has likely learned the jitter to some degree, because the `Matern32Kernel` is preferred over `Matern52Kernel` and `SqExponentialKernel`.

Note: we could circumvent this by using a Student-t process instead of a GP, which would absorb these shocks (i.e., outliers) for us, resulting in a more realistic $\sigma_n^2$ value that would relate more directly to jitter.
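
The 'shock absorber' effect can be illustrated in one dimension. The sketch below is not a Student-t process and uses synthetic residuals with made-up scales. It only shows that a Gaussian scale estimate is inflated by a few outliers, while a Student-t fit keeps a scale close to that of the bulk of the data.

```python
import numpy as np
import scipy.stats

rng = np.random.default_rng(0)

# Hypothetical residuals: small jitter-like noise plus 5% large outliers
inliers = rng.normal(0.0, 0.05, size=1900)
outliers = rng.normal(0.0, 1.0, size=100)
residuals = np.concatenate([inliers, outliers])

# A Gaussian noise model must inflate its scale to explain the outliers
gaussian_sigma = residuals.std()

# A Student-t model explains them with heavy tails instead (location fixed at 0)
df, loc, t_scale = scipy.stats.t.fit(residuals, floc=0.0)

print(f'Gaussian sigma:  {gaussian_sigma:.3f}')
print(f'Student-t scale: {t_scale:.3f} (df = {df:.2f})')
```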

## Can we use `TIMIT-voiced`?

No.

The TIMIT-voiced database does not have GCI information; the frames are nearly constant-length and give only voiced/unvoiced information.
In addition, I could not find a reference for it, and I think these values were auto-generated and not manually verified.

In [None]:
%run init.ipy

from lib import aplawd
from lib import praat
from dgf import bijectors
from dgf import isokernels
from dgf import core
from dgf import constants
from dgf.prior import period

import dynesty
import pandas as pd
import warnings
import itertools
import scipy.stats
import multiprocessing

## Plot samples from APLAWD

The distribution of the unconstrained (i.e., transformed) $z = b^{-1}(T)$ variable is bimodal, pointing to the sex of the speakers.
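
To make the transformation concrete, here is a minimal sketch of a bounded bijector $b$ and its inverse. The bounds and the sigmoid form are assumptions for illustration; the actual bijector in `dgf.bijectors` may differ.

```python
import numpy as np

# Hypothetical bounds on the pitch period in msec (not taken from `dgf.constants`)
T_MIN, T_MAX = 1.0, 20.0

def forward(z):
    """b: map unconstrained z in R to a period T in (T_MIN, T_MAX)."""
    return T_MIN + (T_MAX - T_MIN) / (1.0 + np.exp(-z))

def inverse(T):
    """b^{-1}: map a period T in (T_MIN, T_MAX) to unconstrained z."""
    u = (T - T_MIN) / (T_MAX - T_MIN)
    return np.log(u) - np.log1p(-u)

T = np.array([4.0, 8.0])  # msec; roughly 250 Hz and 125 Hz
z = inverse(T)
print(z, forward(z))
```

With this form, a period outside the bounds maps to NaN, which is the kind of failure the bounds check on `training_pairs_z` guards against.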

In [None]:
recordings = aplawd.APLAWD(__datadir__('APLAWDW/dataset'))
markings = aplawd.APLAWD_Markings(__datadir__('APLAWDW/markings/aplawd_gci'))

def plot_recording_and_degg(recordings, markings, key):
    """Adapted from https://github.com/serwy/aplawdw/blob/master/demo_001.py"""
    aplawd_db = recordings
    markings_db = markings

    # Use the key passed by the caller instead of overwriting it with a random one
    recording = aplawd_db.load_shifted(key)
    recording_gci = markings_db.load(recording.name)

    ax = plt.subplot(211)
    t = np.arange(len(recording.s)) / recording.fs
    plt.plot(t, recording.s, alpha=0.5)
    plt.ylabel('speech')
    plt.title('APLAWD waveform: %s' % recording.name)

    plt.subplot(212, sharex=ax)
    t = np.arange(len(recording.d)) / recording.fs

    plt.plot(t, recording.d, alpha=0.5, label='DEGG')

    plt.plot(t[recording_gci], 0*recording_gci, 'o', ms=5, alpha=0.5,
             label='reference markings')

    plt.legend(loc='lower right', fancybox=True, framealpha=0.5)
    plt.ylabel('DEGG')
    plt.xlabel('time (s)')

    plt.tight_layout()

plot_recording_and_degg(recordings, markings, recordings.random_key())

## Illustrate the data gathering process: `period.yield_training_pairs()`

We pair the manually checked markings with Praat's pulse estimates.
Each recording is split into groups of voiced markings and those groups are matched and aligned with Praat's pulse estimates.
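
The splitting step can be sketched as follows. This is a guess at the idea behind `period.split_markings_into_voiced_groups`, not its actual implementation: the function below is hypothetical and simply cuts the marking sequence wherever two consecutive GCIs are further apart than the maximum period length.

```python
import numpy as np

def split_into_voiced_groups(markings, fs, max_period_msec, min_num_periods):
    """Split GCI sample indices wherever the gap exceeds `max_period_msec`.

    Groups with fewer than `min_num_periods` periods are dropped.
    """
    markings = np.asarray(markings)
    periods_msec = np.diff(markings) / fs * 1000
    # A gap longer than the maximum period marks the end of a voiced group
    split_at = np.where(periods_msec > max_period_msec)[0] + 1
    groups = np.split(markings, split_at)
    # A group with n markings contains n - 1 periods
    return [g for g in groups if len(g) - 1 >= min_num_periods]

fs = 20_000  # Hz; assumed sampling rate, for illustration only
markings = [0, 140, 280, 420, 560, 5000, 5140, 9000, 9150, 9300, 9450]
groups = split_into_voiced_groups(
    markings, fs, max_period_msec=20.0, min_num_periods=3
)
for g in groups:
    print(g, np.diff(g) / fs * 1000)  # markings and their periods in msec
```

In this toy input the middle group `[5000, 5140]` contains a single period and is dropped.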

In [None]:
key = recordings.random_key()

k, m = period.load_recording_and_markers(recordings, markings, key)
pulses = praat.get_pulses(k.s, k.fs)
f = 1000/k.fs

title(f'Full waveform: {key}')
plot(k.s)
for p in m: axvline(p, color='black')
for p in pulses: axvline(p, color='pink')
#xlim(1800, 2300)
show()

figure()
title('Diff of GCI markers')
plot(diff(m)*f)
plot(diff(pulses)*f)
ylabel('msec')
xlabel('index')
show()

MIN_NUM_PERIODS = 3

voiced_groups = period.split_markings_into_voiced_groups(
    m, k.fs, constants.MAX_PERIOD_LENGTH_MSEC, MIN_NUM_PERIODS
)

for group in voiced_groups:
    if len(group) <= MIN_NUM_PERIODS + 1: continue

    group_periods = np.diff(group)*f # msec
    display(group_periods)
    
    group, group_praat = period.align_and_intersect(group, pulses)
    assert len(group) == len(group_praat)
    if len(group) <= MIN_NUM_PERIODS + 1:
        continue

    praat_periods = np.diff(group_praat) / k.fs * 1000 # msec
    if np.any(praat_periods > constants.MAX_PERIOD_LENGTH_MSEC):
        # Discard this and continue; we assume user will never accept
        # such Praat estimates so we don't want to model this case.
        warnings.warn(
            f'Discarded voiced group of GCIs because one of the synced '
            f'Praat periods is longer than {constants.MAX_PERIOD_LENGTH_MSEC} msec'
        )
        continue
    
    figure()
    title('Waveform of local voiced group with GCI estimates')
    t = np.arange(group[0], group[-1])
    plot(t*f, k.s[group[0] : group[-1]])
    for p in group: axvline(p*f, color='black')
    for p in group_praat: axvline(p*f, color='pink')
    xlabel('time [msec]')
    
    show()
    
    figure()
    title('Pitch period length trajectory')
    plot(np.diff(group) * f)
    plot(np.diff(group_praat) * f)
    xlabel('pitch period index')
    show()

## Plot marginal statistics

In [None]:
training_pairs, training_pairs_z = period.get_aplawd_training_pairs()

# Check if everything is within the bijector bounds
assert not np.any(np.concatenate([np.isnan(x) | np.isnan(y) for x, y in training_pairs_z]))

def plot_marginal(training_pairs, lab):
    true_marginal = []
    praat_marginal = []

    for true_group, praat_group in training_pairs:
        true_marginal.append(true_group)
        praat_marginal.append(praat_group)

    true_marginal = np.concatenate(true_marginal)
    praat_marginal = np.concatenate(praat_marginal)

    hist([true_marginal, praat_marginal], bins=50)
    xlabel(lab)
    
    return true_marginal, praat_marginal

In [None]:
len_data_points = np.array([len(pair[0]) for pair in training_pairs])
hist(len_data_points, bins=50)
pd.DataFrame(len_data_points).describe()

In [None]:
true_marginal, praat_marginal = plot_marginal(training_pairs, 'Pitch period [msec]')
pd.DataFrame({'true': true_marginal, 'Praat': praat_marginal}).describe()

In [None]:
true_marginal_z, praat_marginal_z = plot_marginal(training_pairs_z, 'z')
pd.DataFrame({'true': true_marginal_z, 'Praat': praat_marginal_z}).describe()

## Fit GP model

The GP model for the true GCI markings is independent of the Praat observation model, so we train them separately.

The observation noise $\sigma_n^2$ (`noise_sigma**2`) relates to voice jitter.

The `Matern32Kernel` has the highest posterior probability. The fitted source envelope lengthscales $\lambda$ differ quite strongly between the kernels, with highest evidence for $\lambda \approx 10$.

The `Matern12Kernel` did converge, with a very long lengthscale but low evidence. **So there is a preference for some roughness, but not too much: this indicates the learning of a jitter component.**
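
To give a feel for what 'roughness' means here, the sketch below compares the four kernel families at a common lengthscale. These are the standard textbook forms; the parametrization in `dgf.isokernels` may differ, and the fitted lengthscales differ per kernel, so the numbers are illustrative only.

```python
import numpy as np

def matern12(r, scale):
    return np.exp(-r / scale)

def matern32(r, scale):
    a = np.sqrt(3.0) * r / scale
    return (1.0 + a) * np.exp(-a)

def matern52(r, scale):
    a = np.sqrt(5.0) * r / scale
    return (1.0 + a + a**2 / 3.0) * np.exp(-a)

def sq_exponential(r, scale):
    return np.exp(-0.5 * (r / scale) ** 2)

# Correlation between neighbouring pitch periods (lag 1) at lengthscale 10
scale = 10.0
kernels = (matern12, matern32, matern52, sq_exponential)
lag1 = {k.__name__: float(k(1.0, scale)) for k in kernels}

# For a unit-variance process, the expected squared jump between
# neighbours is 2 * (1 - correlation): larger means rougher trajectories
roughness = {name: 2.0 * (1.0 - c) for name, c in lag1.items()}
for name, value in roughness.items():
    print(f'{name:>15}: {value:.4f}')
```

At equal lengthscale the period-to-period jumps shrink from `matern12` to `sq_exponential`, so a preference for `Matern32Kernel` sits between the very rough and the very smooth options.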

The Praat observation error is fitted below.

In [None]:
# The nested sampling inference for these kernels takes around 5 +/- 0.5 hrs
for kernel_name in ('Matern12Kernel', 'Matern32Kernel', 'Matern52Kernel', 'SqExponentialKernel'):
    results = period.model_true_pitch_periods(kernel_name, 32)
    print(kernel_name)
    print(results.summary())
    print('Information (bans):', results.information[-1] * log10(e))

In [None]:
from dynesty import plotting

def plot_kernel_results(kernel_name, kernel_M = 32):
    results = period.model_true_pitch_periods(kernel_name, kernel_M)

    display(kernel_name)
    display(results.summary())
    display('Information (bans)', results.information[-1] * log10(e))
    
    VARIABLE_NAMES = ['mean', 'sigma', 'scale', 'noise_sigma']
    fig, axes = dynesty.plotting.traceplot(
        results, show_titles=True,
        labels=VARIABLE_NAMES,
        verbose=True
    )
    tight_layout()

    fg, ax = dynesty.plotting.cornerplot(results, labels=VARIABLE_NAMES)
    tight_layout()
    show()

plot_kernel_results('Matern12Kernel')
plot_kernel_results('Matern32Kernel')
plot_kernel_results('Matern52Kernel')
plot_kernel_results('SqExponentialKernel')

## Fit Praat observation error

We model `p(T_praat | T_true) = N(T_true, praat_sigma² * I)`. The fitted `praat_sigma` is about 0.05, roughly half of `noise_sigma` above, so the correction is quite substantial.
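
Under this observation model, combining a Gaussian prior over the true periods with a Praat estimate is a standard Gaussian conditioning step. The sketch below uses a hypothetical helper `condition_on_praat` and made-up prior values. In the notebook the model may be defined on the transformed variable $z$ rather than on periods in msec.

```python
import numpy as np

def condition_on_praat(prior_mean, prior_cov, praat_periods, praat_sigma):
    """Posterior over T_true given T_praat ~ N(T_true, praat_sigma**2 * I)."""
    y = np.asarray(praat_periods, dtype=float)
    S = prior_cov + praat_sigma**2 * np.eye(len(y))  # covariance of T_praat
    gain = np.linalg.solve(S, prior_cov).T           # K S^{-1} (K, S symmetric)
    post_mean = prior_mean + gain @ (y - prior_mean)
    post_cov = prior_cov - gain @ prior_cov
    return post_mean, post_cov

# Hypothetical Matern-3/2 prior: sigma = 1, lengthscale = 10, mean 7 msec
P = 5
idx = np.arange(P)
a = np.sqrt(3.0) * np.abs(idx[:, None] - idx[None, :]) / 10.0
prior_cov = (1.0 + a) * np.exp(-a)
prior_mean = np.full(P, 7.0)

praat_periods = np.full(P, 8.0)
post_mean, post_cov = condition_on_praat(
    prior_mean, prior_cov, praat_periods, praat_sigma=0.05
)
print(post_mean)
print(np.sqrt(np.diag(post_cov)))
```

A small `praat_sigma` pulls the posterior mean close to the Praat estimate; a large one leaves the prior almost unchanged.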

In [None]:
# Takes 30 minutes
results = period.model_praat_pitch_periods()

display(results.summary())
display('Information (bans)', results.information[-1] * log10(e))

VARIABLE_NAMES = ['praat_sigma']
fig, axes = dynesty.plotting.traceplot(
    results, show_titles=True,
    labels=VARIABLE_NAMES,
    verbose=True
)
tight_layout()

fg, ax = dynesty.plotting.cornerplot(results, labels=VARIABLE_NAMES)
tight_layout()
show()

## Check the fit

In [None]:
period.fit_aplawd_z()

In [None]:
period_prior = period.marginal_prior()

samples = period_prior.sample(int(1e5), seed=jax.random.PRNGKey(50))

hist(1000/true_marginal, bins=50, alpha=.5, density=True, label='APLAWD');
hist(1000/np.asarray(samples), bins=50, alpha=.5, density=True, label='Fitted model');
legend()
xlabel('Fundamental frequency F0 (Hz)');

In [None]:
P = 30

prior = period.trajectory_prior(P)
samples = prior.sample(5, seed=jax.random.PRNGKey(randint(0,1000)))
plot(samples.T)
xlabel('Pitch period index');
ylabel('Period [msec]');
title('Period trajectory prior');

In [None]:
praat_estimate = [7.]*40

prior = period.trajectory_prior(praat_estimate=praat_estimate)
samples = prior.sample(1, seed=jax.random.PRNGKey(randint(0,1000)))
plot(praat_estimate, '--', label="given Praat estimate")
plot(samples.T, label="conditioned trajectory")
legend()
xlabel('Pitch period index');
ylabel('Period [msec]');
title('Period trajectory prior given Praat estimate');

### Measure jitter

The jitter implied by the prior is reasonably realistic. We use Praat's `jitter (local, absolute)` from <https://www.fon.hum.uva.nl/praat/manual/Voice_2__Jitter.html> because it is less sensitive to what we choose as a consistent group of pitch periods.
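
For reference, for $N$ consecutive periods $T_1, \ldots, T_N$ in a voiced group, the local absolute jitter is the mean absolute difference between consecutive periods:

$$
\text{jitter}_\text{abs} = \frac{1}{N-1} \sum_{i=1}^{N-1} \left| T_{i+1} - T_i \right|
$$

As a worked example with made-up values, the periods $(7.0, 7.1, 6.9, 7.0)$ msec have absolute consecutive differences $(0.1, 0.2, 0.1)$ msec, so $\text{jitter}_\text{abs} = 0.4/3 \approx 0.133$ msec, or about 133 microseconds.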

In [None]:
MAX_JITTER_USEC = 1000
NUM_SAMPLES = int(1e4)
NUMBINS = 50

# Measure APLAWD jitter over voiced groups
def measure_jitter(x):
    # https://www.fon.hum.uva.nl/praat/manual/PointProcess__Get_jitter__local__absolute____.html
    return float(np.mean(np.abs(np.diff(x))))

aplawd_jitter = np.array([measure_jitter(true_period) for (true_period, _) in training_pairs])
aplawd_jitter_usec = aplawd_jitter*1000
aplawd_jitter_usec = aplawd_jitter_usec[aplawd_jitter_usec < MAX_JITTER_USEC]

# Measure jitter from GP prior over one lengthscale
fit_z = period.fit_aplawd_z()
P = int(fit_z['scale']) # Measure over one lengthscale (has little influence)

prior = period.trajectory_prior(P)
samples = prior.sample(NUM_SAMPLES, seed=jax.random.PRNGKey(48702))

prior_jitter = np.mean(np.abs(np.diff(samples,axis=1)),axis=1)
prior_jitter_usec = prior_jitter*1000
prior_jitter_usec = prior_jitter_usec[prior_jitter_usec < MAX_JITTER_USEC]

# Plot
hist(aplawd_jitter_usec, bins=NUMBINS, alpha=0.5, density=True, label='APLAWD');
hist(prior_jitter_usec, bins=NUMBINS, alpha=0.5, density=True, label='prior');
axvline(83.200, label='Pathology threshold')
legend()
title('Measured jitter from APLAWD vs prior')
xlabel('Local absolute jitter (microseconds)');