# Plan after Rome

## Different parameters

- $\theta$: the parameters controlling the basis functions
- $\alpha$: the hyperparameters controlling $\theta$
- `hyper`: top-level hyperparameters such as `fs`, `t` and the `data`. For this we need frozen dicts (`flax.core.frozen_dict.freeze`) and hashable arrays.

## $p(\theta|\alpha)$

This is actually a GP because the number of pitch periods can potentially be infinite. It is just indexed by integers $(k,l)$ instead of reals $(x,x')$. It has learnable hyperparameters and can be conditioned on Praat estimates.

Because we expect $\theta$ that are further apart (measured in the number of pitch periods $|k-l|$ separating them) to be less correlated, we use a covariance function which decays as $\exp\left(-|k-l|^2/(2\lambda^2)\right)$, where $k$ and $l$ index the pitch periods and $\lambda \in \alpha$ is a lengthscale hyperparameter which can be learned from the VTRFormants dataset (see `WRK/corpora`).
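As a minimal sketch of this covariance on integer pitch-period indices (plain NumPy; the name `period_kernel` and the default `lengthscale=7.0` are illustrative assumptions, not learned values):

```python
import numpy as np

def period_kernel(k, l, variance=1.0, lengthscale=7.0):
    """Squared-exponential covariance between pitch periods k and l."""
    k = np.asarray(k, dtype=float)
    l = np.asarray(l, dtype=float)
    return variance * np.exp(-(k - l) ** 2 / (2.0 * lengthscale ** 2))

# Gram matrix over the first five pitch periods
periods = np.arange(5)
K = period_kernel(periods[:, None], periods[None, :])
print(np.round(K, 3))
```

The matrix is symmetric with ones on the diagonal, and its entries shrink as $|k-l|$ grows.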

We only take the bandwidths and frequencies to be correlated, and their correlations can be learned with a [coregionalization kernel](https://gpflow.github.io/GPflow/develop/notebooks/advanced/coregionalisation.html#A-simple-demonstration-of-coregionalization) given observed $\theta$s from the VTRFormants dataset. Motivation:
1. We assume a source-filter model, so the filter (bandwidths and frequencies) is uncoupled from the source (pitch periods, variance, scale, open quotient).
2. Furthermore, our parametrization is chosen such that the source parameters (pitch periods, variance, scale, open quotient) are maximally independent, so we take them to be independent.
3. In addition, we have no empirical data from which to learn correlations among the parameters in (2.).
4. We infer the correlations between bandwidths and frequencies from the F1-F3 and B1-B3 data in VTRFormants, and use them as a proxy for the correlations between (LPC) poles, just as we used the empirical $\log F2/F1$ ratios as data for the Pareto chain, which parametrizes not only formants but also poles.
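The structure motivated above can be sketched as a Kronecker product of a coregionalization matrix with the pitch-period kernel. This is only an illustration in plain NumPy: the form $B = WW^T + \operatorname{diag}(\kappa)$ follows the linked coregionalization kernel, but `W` and `kappa` are random placeholders rather than values learned from VTRFormants, and the parameter ordering is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

# Filter block: (bandwidth, frequency) share a low-rank correlation structure.
n_filter, rank = 2, 1
W = rng.normal(size=(n_filter, rank))      # placeholder, would be learned
kappa = np.ones(n_filter)                  # placeholder, would be learned
B_filter = W @ W.T + np.diag(kappa)

# Source block: (T, var, scale, OQ) are taken to be independent.
n_source = 4
n_params = n_filter + n_source
B = np.zeros((n_params, n_params))
B[:n_filter, :n_filter] = B_filter
B[n_filter:, n_filter:] = np.eye(n_source)

# Squared-exponential kernel over pitch-period indices.
periods = np.arange(3, dtype=float)
lengthscale = 7.0
K_period = np.exp(-(periods[:, None] - periods[None, :]) ** 2
                  / (2.0 * lengthscale ** 2))

# Entry ((p, k), (q, l)) of the full covariance is B[p, q] * K_period[k, l].
K = np.kron(B, K_period)
print(K.shape)
```

With this ordering, the rows and columns of `K` are grouped by parameter type, so the filter-source cross blocks are exactly zero.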

Note that our $p(\theta|\alpha)$ is actually an MVN in the **log domain**, which is then further restricted to hard bounds via `tfb.SoftClip`; the bounds are set on theoretical grounds. See `bounded_exp()`. In the log domain our observation noise can be interpreted as a percentage (relative) error.
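To make the relative-error reading concrete, here is a small sketch that ignores the `SoftClip` bounds (an assumption that only holds well inside them): additive Gaussian noise with standard deviation $\sigma_n$ on $\log x$ becomes a multiplicative factor $e^{\pm\sigma_n}$ on $x$. The helper name `relative_interval` is hypothetical.

```python
import numpy as np

def relative_interval(noise_variance):
    """One-sigma multiplicative error factors implied by log-space noise."""
    sigma = np.sqrt(noise_variance)
    return np.exp(-sigma), np.exp(sigma)

# A log-space noise variance of 0.1 applied to an estimate of 200
lo, hi = relative_interval(0.1)
estimate = 200.0
print(estimate * lo, estimate * hi)
```

For a noise variance of 0.1 the factors are roughly 0.73 and 1.37, so this is a fairly loose constraint; for small $\sigma_n$ the factors approach $1 \pm \sigma_n$.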

### Other ideas

- Use the EGG dataset to build a complete training set for learning the $\alpha$ hyperparameters: very roughly fit the OQ to the EGG signal, and take the poles and the pitch periods from Praat. For this we need to be able to correlate the dGF sufficiently well with the EGG signal. We can maybe use @Alku2002 to estimate OQ.
- To calibrate fractional Praat errors ($\sigma_n^2$ in log space), we can generate synthetic signals, run Praat on them and check the errors.
- Learn or set separate GP hyperparameters (mainly $\lambda$) for the "source parameters" (pitch periods, variance, scale, open quotient) and for the "filter parameters" (bandwidths and frequencies). The former can come from the TIMITvoiced dataset in `WRK/corpora`, which has pitch period lengths; the latter from VTRFormants.

### General approach

- Either we learn $\alpha = \alpha^*$ from a dataset and then keep it fixed for inferring $\theta$ during real work (with MAP or maybe even with nested sampling if the $p(\theta|\alpha^*,\text{Praat estimates})$ is good enough?);
- Or we optimize $\alpha$ and $\theta$ alternately with MAP optimization.

Both can be seen as MacKay's hyperparameter "two-stage" optimization because we already incorporate mass by integrating out the amplitudes. The second option can even be seen as "triple-stage" optimization. I think the first option is better because we already condition on the Praat estimates.
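The first option can be illustrated with a toy version of the learning stage: pick the lengthscale $\lambda^*$ that maximizes the summed GP log marginal likelihood over a set of training tracks. Everything here is an assumption made for illustration: the tracks are synthetic rather than taken from VTRFormants, only $\lambda$ is learned, and a grid search stands in for a proper optimizer.

```python
import numpy as np

def se_gram(n, lengthscale, noise_var):
    """Squared-exponential Gram matrix over n pitch periods plus noise."""
    idx = np.arange(n, dtype=float)
    K = np.exp(-(idx[:, None] - idx[None, :]) ** 2 / (2.0 * lengthscale ** 2))
    return K + noise_var * np.eye(n)

def total_log_marginal(tracks, K):
    """Sum of zero-mean MVN log densities of each row of `tracks` under K."""
    L = np.linalg.cholesky(K)
    Z = np.linalg.solve(L, tracks.T)           # shape (n, n_tracks)
    n, m = Z.shape
    return (-0.5 * np.sum(Z ** 2)
            - m * np.sum(np.log(np.diag(L)))
            - 0.5 * m * n * np.log(2.0 * np.pi))

# Synthetic "training set": 20 log-parameter tracks of 30 pitch periods each
rng = np.random.default_rng(0)
n, noise_var, true_lengthscale = 30, 0.01, 5.0
tracks = rng.multivariate_normal(
    np.zeros(n), se_gram(n, true_lengthscale, noise_var), size=20)

# Stage one: choose lambda* on a grid, then keep it fixed when inferring theta
grid = np.linspace(1.0, 15.0, 29)
scores = np.array([total_log_marginal(tracks, se_gram(n, ls, noise_var))
                   for ls in grid])
lengthscale_star = grid[int(np.argmax(scores))]
print(lengthscale_star)
```

In the second option this step would instead be interleaved with MAP updates of $\theta$.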

## Philosophy

Because we incorporate all this prior information about smoothness etc., we don't need smoothness heuristics to select "winning" attempts as in @Barreda2021 and Praat; the approach is more like @Mehta2012.

In [1]:
%run init.ipy
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd

from dgf import isokernels

key = jax.random.PRNGKey(12)



In [2]:
def bounded_exp(low, high):
    low = jnp.float64(low)
    high = jnp.float64(high)
    return tfb.Chain([tfb.SoftClip(low, high), tfb.Exp()])

bounded_exp(2., 10.).forward(log(randn(10)*2 + 5))



DeviceArray([8.58430056, 6.83282323, 2.15453249, 2.39243687, 6.23591534,
             3.37128556, 5.81306199, 5.63724106, 2.30861479, 4.26262605],            dtype=float64)

In [3]:
bijector = bounded_exp(50, 450)

kernel = isokernels.SqExponentialKernel(1., 6.)
index_points = jnp.linspace(0., 20., 4)[:,None]
observation_index_points = index_points

observations = jnp.array([200., 230., 250., 210.])
log_observations = bijector.inverse(observations)
print(log_observations)
observation_noise_variance = .1 # Praat estimate error in log space

log_gp = tfd.GaussianProcessRegressionModel(
    kernel,
    index_points,
    observation_index_points,
    log_observations,
    observation_noise_variance
)

gp = tfd.TransformedDistribution(
    log_gp,
    bijector
)

key, subkey = jax.random.split(key)
x = gp.sample(1, seed=subkey)
x, gp.log_prob(x)



[5.29831737 5.43807931 5.52146092 5.34710753]




(DeviceArray([[ 73.57689891, 192.43618943, 255.72594565, 102.02069503]], dtype=float64),
 DeviceArray([-21.52762689], dtype=float64))

In [4]:
import distrax

class MyDistribution(distrax.Distribution):
    def __init__(self):
        super().__init__()

    def _sample_n(self, key, n):
        samples = ...
        return samples

    def log_prob(self, value):
        log_prob = ...
        return log_prob

    def event_shape(self):
        event_shape = ...
        return event_shape

    def _sample_n_and_log_prob(self, key, n):
        # Optional. Only when more efficient implementation is possible.
        samples, log_prob = ...
        return samples, log_prob

key = jax.random.PRNGKey(1234)

distrax_dist = distrax.Normal(0., 1.)
wrapped_dist = distrax.to_tfp(distrax_dist)
metadist_tfp = tfd.Sample(wrapped_dist, sample_shape=[3])

samples = metadist_tfp.sample(seed=key)
print(metadist_tfp.log_prob(samples))  # Prints -3.3409896

ModuleNotFoundError: No module named 'distrax'

In [None]:
isinstance(distrax_dist, distrax.Distribution)

In [None]:
d = MyDistribution()
distrax.to_tfp(d)

In [None]:
CORRELATED_PARAMS = ('bandwidth', 'frequency')

def param_correlation(r, s, c=0.3):
    """Correlation between parameter types r and s, normalized to have a variance of one"""
    if r == s:
        return 1.

    if (r in CORRELATED_PARAMS) and (s in CORRELATED_PARAMS):
        # This should be a coregionalization model: a low rank (rank 3 => trained
        # on F1-F3 and B1-B3) model extendable to 10 or more poles
        return c
    else:
        return 0.

def k_period(k, l, variance=1., lengthscale=7.):
    # One single lengthscale and variance describe the continuity of all the parameters
    return variance*exp(-(k-l)**2/(2*lengthscale**2))

def k(x, y):
    # x and y are (parameter name, pitch period index) pairs
    r, period_x = x
    s, period_y = y
    return param_correlation(r, s) * k_period(period_x, period_y)


# Model covariance between two variances that are one pitch period apart
x = ('var', 0)
y = ('var', 1)
k(x, y)

In [None]:
# Model covariance between two uncorrelated params that are two pitch periods apart
x = ('T', 1)
y = ('var', 3)
k(x, y)

In [None]:
# Model covariance between two correlated params that are four pitch periods apart
x = ('bandwidth', 5)
y = ('frequency', 1)
k(x, y)