# Investigations into resampling PMFs

<br>

### Uncertainty in a Sampling Distribution
Given IID samples drawn from a PMF, what is the uncertainty in the sampled (empirical) PMF? Below we use the bootstrap to compute the uncertainty in each probability of a PMF over a discrete 1-D state space. By examining its behavior, we guess at an analytic description of this uncertainty. Finally, we examine how that uncertainty propagates into the space of log probabilities (with free energy calculations in mind).

<br>

### Blocked Bootstrapping
The blocked bootstrap extends the idea of bootstrapping to a set of correlated samples (a data "signal"): contiguous blocks of samples, rather than individual samples, are resampled, so correlations within each block are preserved. In free energy calculations, samples are often correlated (e.g. when generated by MC or MD). We apply a blocked bootstrap to compute PMF uncertainties for these correlated samples.

<br>

#### Resources
- [The Uncertainty in the Estimate of the Discrete Probability](https://apt.cs.manchester.ac.uk/ftp/pub/ai/jls/CS2411/prob97/node11.html)


###### Note to self: Some of the below fails stochastically. Just rerun the failed bit!

-------------------------------

In [None]:
## imports & configuration

# standard imports

# custom imports
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
from scipy.spatial.transform import Rotation
import tqdm.notebook

# inline plots
%config InlineBackend.figure_formats = ['svg']
# %matplotlib notebook

# jupyter theme (optional: skipped when jupyterthemes is not installed)
# NOTE(review): only ImportError is caught; if `jt.jtplot` is not exposed as an
# attribute of the package, `jt.jtplot.style()` would raise AttributeError instead
# (`from jupyterthemes import jtplot` is the usual form) -- confirm
try:
    import jupyterthemes as jt
    jt.jtplot.style()
except ImportError:
    pass

-------------------------------

## Uncertainty in a Sampling Distribution

In [None]:
# prep and seed random number generator
# (a single module-level Generator shared by the cells below -- e.g. simplex()
# and rng.choice(); seeded with 0 so its draws repeat when the cells are rerun
# in order. scipy.stats.bootstrap is not given this generator, so its
# resampling is not covered by the seed)
rng = np.random.default_rng(0)

In [None]:
## sample from a n-simplex

# sample n points uniformly from the probability simplex in R^d
# (the (d-1)-simplex {x : x_i >= 0, sum_i x_i = 1})
def simplex(n, d, generator=None):
    """Draw ``n`` points uniformly from the probability simplex in ``R^d``.

    Normalizing i.i.d. Exp(1) draws yields a flat Dirichlet(1, ..., 1)
    sample, i.e. a uniform distribution over the simplex.

    Parameters
    ----------
    n : int
        Number of points to draw.
    d : int
        Number of coordinates (states) per point.
    generator : numpy.random.Generator, optional
        Source of randomness; defaults to the notebook-level ``rng``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, d)`` whose rows are non-negative and sum to 1.
    """
    gen = rng if generator is None else generator
    # draw Exp(1) directly: -log(U) is inf whenever U == 0, which random() can
    # return, and that would turn the whole row into NaN after normalizing
    exps = gen.standard_exponential((n, d))
    return exps / exps.sum(axis=-1, keepdims=True)  # normalize each row to sum 1

# rotate 2-D simplex points (x, 1 - x) by 45 degrees and histogram them;
# the result should look like a uniform distribution
plt.figure(figsize=(6, 3))
r_mat = Rotation.from_euler('z', np.pi/4).as_matrix()[:2, :2]
rotated = r_mat.dot(simplex(100000, 2).T)
# the first rotated coordinate is (x - y)/sqrt(2); rescaling by sqrt(1/2) and
# shifting by 1/2 maps it back onto x itself, in [0, 1]
plt.hist(rotated[0]*np.sqrt(0.5) + 0.5, bins=100, density=True)
plt.xlabel("parameter along $(x + y = 1)$")
plt.ylabel("density")
plt.tight_layout()

In [None]:
## compute bootstrap-based uncertainty in discrete distribution

# configuration
n_samples = 1000 # number of samples
d = 5 # number of states

# generate samples from random distribution with d states
states = np.arange(d)
ps = simplex(1, d)[0]
samples = rng.choice(states, n_samples, p=ps)
# bincount returns one count per state even when a state was never drawn;
# np.unique would silently drop it and misalign `ps` with `states`
counts = np.bincount(samples, minlength=d)
ps = counts/n_samples

# bootstrap to compute probability uncertainties
# (a never-drawn state gives a degenerate, all-zero bootstrap distribution,
# so its standard error is 0 and scipy may warn)
errors = np.array([
    scipy.stats.bootstrap(samples[None, :], lambda _samples: (_samples == state).sum()/len(_samples), vectorized=False).standard_error
    for state in states
])

# create histogram
plt.bar(states, ps, yerr=errors)
plt.xlabel("state")
plt.ylabel("density")
plt.tight_layout()

In [None]:
## method to compute bootstrap-based uncertainty estimates
def compute_uncertainty_in_distribution(n_samples, d):
    """Sample a random d-state PMF and bootstrap the uncertainty of its estimate.

    A generating PMF is drawn with the notebook's ``simplex`` helper, and
    ``n_samples`` states are drawn from it with the notebook-level ``rng``.
    The function returns the empirical PMF together with per-state bootstrap
    standard errors.

    Parameters
    ----------
    n_samples : int
        Number of IID samples drawn from the generating PMF.
    d : int
        Number of states.

    Returns
    -------
    ps : numpy.ndarray
        Empirical probabilities, shape ``(d,)``; zero for unsampled states.
    errors : numpy.ndarray
        Bootstrap standard error of each entry of ``ps``, shape ``(d,)``.
    """
    # generate samples from random distribution with d states
    states = np.arange(d)
    generating_ps = simplex(1, d)[0]
    samples = rng.choice(states, n_samples, p=generating_ps)
    # bincount always yields d entries, even for states that were never drawn
    # (np.unique would drop them, giving ragged results across calls)
    ps = np.bincount(samples, minlength=d)/n_samples

    # bootstrap the sampled frequency of each state; each lambda is consumed
    # immediately by bootstrap, so capturing the loop variable `state` is safe
    errors = np.array([
        scipy.stats.bootstrap(
            samples[None, :],
            lambda _samples: (_samples == state).sum()/len(_samples),
            vectorized=False,
        ).standard_error
        for state in states
    ])

    # return sampling distribution and associated uncertainties
    return ps, errors

In [None]:
# compute errors for a variety of sample sizes (6 states each)
ns = np.arange(200, 1100, 100)
results = [compute_uncertainty_in_distribution(n, 6) for n in ns]
# split the (ps, errors) pairs into two parallel tuples
pses = tuple(result[0] for result in results)
errorses = tuple(result[1] for result in results)

In [None]:
## examine relationship between errors and sample size

# initialize figure
fig, axes = plt.subplots(1, 2, figsize=(10, 3))

# plot the constant (== error**2 / probability) averaged across states;
# error bars are the standard error of that mean over the states
scalings = (np.stack(errorses)**2/np.stack(pses))
axes[0].errorbar(ns, scalings.mean(axis=-1), yerr=scalings.std(axis=-1)/np.sqrt(scalings.shape[1]))
axes[0].scatter(ns, scalings.mean(axis=-1))
axes[0].set_xlabel("# of samples")
# raw strings: "\s" is an invalid escape sequence in a plain string literal
axes[0].set_ylabel(r"$\sigma(p_i)^2$ / $p_i$")

# recognize 1/n shape, plot full error relationship; the dashed reference is
# the state-average of (1 - p_i), which is 1 - 1/d because the p_i sum to 1
axes[1].axhline(1-(1/scalings.shape[1]), color='k', ls='--', lw=0.5)
axes[1].plot(ns, scalings.mean(axis=-1)*ns, "o-")
axes[1].set_xlabel("# of samples")
axes[1].set_ylabel(r"$n * \sigma(p_i)^2$ / $p_i$")

# display figure
fig.tight_layout()

In [None]:
# deduce relationship between probability, # of states, and # of samples
def std_err(p, d, n):
    """Approximate standard error of a probability ``p`` estimated from ``n`` samples.

    Computes sqrt(p * (1 - 1/d) / n): the binomial p * (1 - p) with (1 - p)
    replaced by its average over the d states, 1 - 1/d. Operates
    elementwise when ``p`` is an array.
    """
    mean_complement = 1 - (1/d) # !avg(ps) == 1-(1/d)
    return np.sqrt(p*mean_complement/n)

In [None]:
# compare empirical (bootstrap) and analytic uncertainties of each p
analytic = [std_err(p, d, n_samples) for p in ps]
lo, hi = std_err(0, d, n_samples), std_err(1, d, n_samples)
plt.plot((lo, hi), (lo, hi), c='k', ls='--', lw=1) # y = x reference line
plt.scatter(errors, analytic)
plt.xlabel("empirical uncertainty")
plt.ylabel("analytic uncertainty")
plt.tight_layout()

-----------------------

In [None]:
## bootstrap the -ln(p) uncertainty

# configuration
n_samples = 10000 # number of samples
d = 10 # number of states

# generate samples from random distribution with d states
states = np.arange(d)
ps = simplex(1, d)[0]
samples = rng.choice(states, n_samples, p=ps)
# bincount keeps one entry per state, even for states that were never drawn
# (np.unique would drop them and misalign `ps` with `states`)
counts = np.bincount(samples, minlength=d)
ps = counts/n_samples

# bootstrap to compute -ln(p) uncertainties
# NOTE: a resample in which a state never appears gives -ln(0) = inf, so the
# standard error of a very rare state can come out as inf/nan
boots = [
    scipy.stats.bootstrap(samples[None, :], lambda _samples: -np.log((_samples == state).sum()/len(_samples)), vectorized=False)
    for state in states
]
errors = np.array([b.standard_error for b in boots])

# plot -logPs with errors
plt.bar(states, -np.log(ps), yerr=errors)
plt.xlabel("state")
plt.ylabel("-logP")
plt.tight_layout()

In [None]:
# compare empirical and analytic -logp uncertainties; to first order,
# sigma(-ln p) ~= sigma(p) / p
analytic = np.array([std_err(p, d, n_samples) for p in ps])/ps
# sigma(p)/p shrinks as p grows, so the rarest state has the largest value
upper = std_err(ps.min(), d, n_samples)/ps.min()
plt.plot((0, upper), (0, upper), c='k', ls='--', lw=1) # y = x reference line
plt.scatter(errors, analytic)
plt.xlabel("empirical uncertainty")
plt.ylabel("analytic uncertainty")
plt.tight_layout()

In [None]:
# slowly compute dGs from i -> i+1 (takes like 2-3 minutes)
# - bincount(minlength=...) keeps a slot for every state even when a resample
#   misses one, so index i always refers to the same pair of states
#   (np.unique counts would shift and compare the wrong states)
# - only standard_error is used, and it is computed from the bootstrap
#   distribution alone, so method='percentile' skips BCa's extra
#   leave-one-out jackknife pass over all samples
# - each lambda is consumed immediately by bootstrap, so capturing `i` is safe
dG_errs = np.array([
    scipy.stats.bootstrap(
        samples[None, :],
        lambda _samples: np.diff(-np.log(np.bincount(_samples, minlength=len(states))/len(_samples)))[i],
        vectorized=False,
        method='percentile',
    ).standard_error
    for i in range(len(states)-1)
])

In [None]:
## compare empirical and analytic dG errors

# analytic dG error computation: per-state -ln(p) errors, added in quadrature
_err = std_err(ps, d, n_samples)/ps
_err_sq = _err**2
_dG_errs = np.sqrt(_err_sq[:-1] + _err_sq[1:])
# Why this factor works: it cancels the (1 - 1/d) inside std_err, leaving
# Var(dG_i) = (1/p_i + 1/p_{i+1}) / n. That is the delta-method variance of
# ln(p_i) - ln(p_{i+1}) for multinomial counts: (1-p_i)/(n p_i) +
# (1-p_{i+1})/(n p_{i+1}) plus 2/n from the covariance Cov(p_i, p_j) = -p_i p_j / n.
_dG_errs *= np.sqrt(d/(d-1))

# compare empirical and analytic uncertainties
top = max(dG_errs.max(), _dG_errs.max())
plt.plot((0, top), (0, top), c='k', ls='--', lw=1) # y = x reference line
plt.scatter(dG_errs, _dG_errs)
plt.xlabel("empirical uncertainty")
plt.ylabel("analytic uncertainty")
plt.tight_layout()

In [None]:
# method to compute i->i+1 dG error from sampling distribution (and # of samples)
def dG_err(ps, n_samples):
    """Analytic standard errors of a sampled PMF and of its neighbouring dGs.

    Parameters
    ----------
    ps : array_like
        Probabilities of the d states (the estimated PMF).
    n_samples : int
        Number of IID samples the PMF was estimated from.

    Returns
    -------
    p_err : numpy.ndarray
        Standard error of each p_i, shape ``(d,)``: the binomial
        sqrt(p_i * (1 - p_i) / n).
    dg_err : numpy.ndarray
        Standard error of dG_i = -ln(p_{i+1} / p_i), shape ``(d - 1,)``:
        sqrt((1/p_i + 1/p_{i+1}) / n). A state with p = 0 gives an
        infinite error (with a divide-by-zero warning).
    """
    ps = np.asarray(ps, dtype=float)
    # exact per-state binomial standard error (uses each state's own 1 - p_i)
    p_err = np.sqrt(ps*(1 - ps)/n_samples)
    # first-order (delta-method) variance of ln(p_i) - ln(p_j) for multinomial
    # counts: (1-p_i)/(n p_i) + (1-p_j)/(n p_j) + 2/n, where the 2/n comes from
    # Cov(p_i, p_j) = -p_i p_j / n; this simplifies to (1/p_i + 1/p_j)/n
    inv_ps = 1/ps
    dg_err = np.sqrt((inv_ps[:-1] + inv_ps[1:])/n_samples)
    return p_err, dg_err # returns standard error in p and -log(p_{i+1}/p_i)

-----------------------

## Blocked Bootstrap

In [None]:
# generate random samples from a random 20-state distribution
n_samples = 10000
ps = simplex(1, 20)[0]
samples = rng.choice(np.arange(len(ps)), size=n_samples, p=ps)

In [None]:
# blocked bootstrap settings
# NOTE: the block size is set indirectly: block_len = len(samples)//n_blocks_per_signal
n_blocks_per_signal = 1000 # number of blocks concatenated into each resampled signal
n_resamples = 9999 # number of resampled signals to construct
n_blocks = 1000 # block pool built from n_blocks//n_blocks_per_signal tilings; should be a multiple of n_blocks_per_signal

In [None]:
# generate blocks from signal
# Each tiling cuts the signal into consecutive blocks of block_len samples,
# starting at a random offset in [0, block_len) and wrapping around the end of
# the signal (that is what the padding below is for).
block_len = len(samples)//n_blocks_per_signal
padded_samples = np.concatenate([
    samples, # support wrap around by adding an extra block length:
    samples[:block_len+(-len(samples)%block_len)]
])
tiled_len = len(padded_samples) - block_len # multiple of block_len, >= len(samples)
# one circular tiling per random offset; keep every block so each sample is
# represented in every tiling (at least one tiling even if n_blocks is small)
blocks = np.concatenate([
    padded_samples[offset:offset+tiled_len].reshape(-1, block_len)
    for offset in rng.integers(block_len, size=max(1, n_blocks//n_blocks_per_signal))
])
# draw n_blocks_per_signal random blocks per resample and join them:
# signals has shape (n_resamples, n_blocks_per_signal*block_len)
signals = np.concatenate(
    blocks[
        rng.integers(len(blocks), size=n_blocks_per_signal*n_resamples)
    ].reshape(n_blocks_per_signal, n_resamples, -1) # indexing blocks is the slow step of this cell
, axis=1)

In [None]:
# compile per-state counts for each resampled signal
indices, truth = np.unique(samples, return_counts=True)
counts = np.zeros((len(signals), len(indices)))
for row, signal in enumerate(tqdm.notebook.tqdm(signals)):
    present, tally = np.unique(signal, return_counts=True)
    counts[row, present] = tally # states absent from this signal stay 0
probs = counts/signals.shape[1]
assert np.allclose(probs.sum(axis=-1), 1)

In [None]:
# plot mean bootstrapped PMF against true PMF
plt.figure(figsize=(5, 3))
# index the generating PMF by its own states: `indices` only lists states that
# were actually sampled and would mismatch `ps` if any state was never drawn
plt.bar(np.arange(len(ps)), ps)
plt.bar(indices, probs.mean(axis=0), yerr=probs.std(axis=0), width=0.5)
plt.legend(["Generating PMF", "Mean Boot PMF"])
plt.tight_layout()

-----------------------

In [None]:
# generates PMFs from blocked bootstrap of data signal
def blocked_bootstrap_pmfs(
    data, # this is the data signal (1-D array)
    n_blocks_per_signal, # this determines block size (choose carefully)
    n_blocks=1000, # should be a multiple of n_blocks_per_signal
    n_resamples=9999, # number of resampled signals to construct
    generator=None, # numpy Generator; defaults to the notebook-level `rng`
):
    """Resample a correlated signal of state indices with a circular blocked bootstrap.

    The signal is cut into blocks of ``len(data) // n_blocks_per_signal``
    consecutive samples. The block pool consists of
    ``n_blocks // n_blocks_per_signal`` tilings (at least one). Each tiling
    starts at a random offset and wraps around the end of the signal. Every
    resampled signal joins ``n_blocks_per_signal`` blocks drawn from the pool
    with replacement, and its PMF is returned.

    Parameters
    ----------
    data : array_like
        1-D signal whose values are the state indices ``0..k-1``; every state
        must occur at least once.
    n_blocks_per_signal : int
        Number of blocks per resampled signal, in ``[1, len(data)]``. This sets
        the block length, which should exceed the signal's correlation time.
    n_blocks : int, optional
        Approximate size of the block pool; should be a multiple of
        ``n_blocks_per_signal``.
    n_resamples : int, optional
        Number of resampled signals.
    generator : numpy.random.Generator, optional
        Source of randomness; defaults to the notebook-level ``rng``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_resamples, k)``; row ``r`` is the PMF of resample ``r``.

    Raises
    ------
    AssertionError
        If the values of ``data`` are not exactly the indices ``0..k-1``.
    ValueError
        If ``n_blocks_per_signal`` is not in ``[1, len(data)]``.
    """
    gen = rng if generator is None else generator
    data = np.asarray(data)

    # assess input signal values are valid indices
    indices = np.unique(data)
    assert np.all(indices == np.arange(len(indices))), "The input values must be indices of the states they represent!"
    data = data.astype(np.intp) # integer dtype for bincount (accepts e.g. 0.0, 1.0, ...)
    n_states = len(indices)
    if not 1 <= n_blocks_per_signal <= len(data):
        raise ValueError("n_blocks_per_signal must be between 1 and len(data)")

    # generate blocks from signal: pad circularly so a tiling can start at any
    # offset in [0, block_len) and wrap around to the beginning of the signal
    block_len = len(data)//n_blocks_per_signal
    tiled_len = len(data) + (-len(data) % block_len) # rounded up to a multiple of block_len
    padded = np.concatenate([data, data[:tiled_len - len(data) + block_len]])
    n_tilings = max(1, n_blocks//n_blocks_per_signal)
    blocks = np.concatenate([
        padded[offset:offset + tiled_len].reshape(-1, block_len)
        for offset in gen.integers(block_len, size=n_tilings)
    ])

    # the counts of a resampled signal are the sum of its blocks' counts, so
    # count each block once instead of building every resampled signal
    block_counts = np.stack([np.bincount(block, minlength=n_states) for block in blocks])

    # picks[m, r] is the m-th block of resampled signal r; accumulate per
    # position to avoid a (n_resamples, signal length) intermediate array
    picks = gen.integers(len(blocks), size=(n_blocks_per_signal, n_resamples))
    counts = np.zeros((n_resamples, n_states))
    for chosen in picks:
        counts += block_counts[chosen]

    # return discrete distributions for each resample
    probs = counts/(n_blocks_per_signal*block_len)
    assert np.allclose(probs.sum(axis=-1), 1)
    return probs

In [None]:
# plot blocked-bootstrap standard deviations against the analytic ones
n_samples = 10000
ps = simplex(1, 20)[0]
trajectory = rng.choice(np.arange(len(ps)), size=n_samples, p=ps)
probs = blocked_bootstrap_pmfs(trajectory, 100)
boot_std = probs.std(axis=0)
plt.figure(figsize=(5, 3))
plt.plot((0, boot_std.max()), (0, boot_std.max()), c='k', ls='--', lw=1) # y = x reference line
plt.scatter(boot_std, dG_err(ps, n_samples)[0])
plt.xlabel("blocked bootstrap")
plt.ylabel("analytic")
plt.tight_layout()
print("NOTE: The error between bootstrapped and analytic is most strongly a function of block size!")

-----------------------