In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from bams.testsystems import *
from bams.sams_adapter import SAMSAdaptor

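## Rao-Blackwellized vs binary update for Gaussian mixtures

This notebook compares how quickly SAMS estimates of relative free energies converge under two kinds of observation. In the Rao-Blackwellized update, the adaptor is fed the sampler's full state weights. In the binary update, it is fed the observation returned by each sampling step. Each experiment is repeated on many Gaussian mixtures with randomly drawn true free energies. The per-iteration mean-squared error is summarised by its median and its 2.5–97.5 percentile band.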

In [None]:
# Parameters of the system and free energies
f_range = 3.0   # width of the uniform interval the true free energies are drawn from
nstates = 10    # number of states (length of the free-energy vector)

# Parameters of sams
niterations = 1000  # adaptor updates per repeat
repeats = 200       # independent repeats, each with freshly drawn free energies
beta = 0.6          # passed to SAMSAdaptor(beta=...); presumably the gain-decay exponent — TODO confirm

# Parameters of the sampler
nmoves = 1     # first argument of generator.sample(); presumably moves per call — TODO confirm
save_freq = 1  # second argument of generator.sample()

# Seed NumPy's global RNG so that Restart & Run All reproduces the same random
# free energies (and, assuming the samplers draw from np.random, the same runs).
seed = 0
np.random.seed(seed)

### Rao-Blackwellized update

In [None]:
# Rao-Blackwellized SAMS: after every sampling step the adaptor receives the
# sampler's full `weights` vector as its noisy observation.
rb_aggregate_msd = np.zeros((repeats, niterations))
for rep in range(repeats):
    # Random true free energies, shifted so that state 0 is the reference (f = 0).
    draws = np.random.uniform(low=-f_range / 2.0, high=f_range / 2.0, size=nstates)
    f_true = draws - draws[0]
    sigmas = gen_sigmas(sigma1=1, f=f_true)

    generator = GaussianMixtureSampler(sigmas=sigmas)
    adaptor = SAMSAdaptor(nstates=nstates, beta=beta)

    for it in range(niterations):
        generator.sample(nmoves, save_freq)
        current_state, observation = generator.state, generator.weights
        # The negated adaptor output is both the current estimate and the
        # value written back into the sampler's `zetas`.
        zetas = -adaptor.update(
            state=current_state,
            noisy_observation=observation,
            histogram=generator.state_counter,
        )
        generator.zetas = zetas
        # Squared error averaged over states 1..nstates-1 (state 0 is the fixed reference).
        rb_aggregate_msd[rep, it] = np.mean(np.square(f_true[1:] - zetas[1:]))

### Binary update

In [None]:
# Binary SAMS update: the adaptor receives the value returned by
# `generator.sample()` as its noisy observation, rather than the sampler's
# full `weights` vector.
binary_aggregate_msd = np.zeros((repeats, niterations))
for r in range(repeats):
    # Same prior as the Rao-Blackwellized experiment so the two curves are directly
    # comparable; shifted so that state 0 is the reference (f = 0).
    f_true = np.random.uniform(low=-f_range/2.0, high=f_range/2.0, size=nstates)
    f_true -= f_true[0]
    sigmas = gen_sigmas(sigma1=1, f=f_true)

    generator = GaussianMixtureSampler(sigmas=sigmas)
    adaptor = SAMSAdaptor(nstates=nstates, beta=beta)

    for i in range(niterations):
        noisy = generator.sample(nmoves, save_freq)
        state = generator.state
        # The negated adaptor output is both the current estimate and the
        # value written back into the sampler's `zetas`.
        z = -adaptor.update(state=state, noisy_observation=noisy, histogram=generator.state_counter)
        generator.zetas = z
        # Squared error averaged over states 1..nstates-1 (state 0 is the fixed reference).
        binary_aggregate_msd[r, i] = np.mean((f_true[1:] - z[1:])**2)

### Plotting

In [None]:
# Rao-Blackwellized
alpha = 0.3

mu = np.percentile(rb_aggregate_msd, 50, axis=0)
upper = np.percentile(rb_aggregate_msd, 97.5, axis=0)
lower = np.percentile(rb_aggregate_msd, 2.5, axis=0)
t = np.arange(1, len(mu) + 1)
plt.semilogy(t, mu, lw=2, label='RB')
plt.fill_between(t, lower, upper, alpha=alpha)
# Binary
#mu = binary_aggregate_msd.mean(axis=0)
mu = np.percentile(binary_aggregate_msd,50, axis=0)
upper = np.percentile(binary_aggregate_msd, 97.5, axis=0)
lower = np.percentile(binary_aggregate_msd, 2.5, axis=0)
t = np.arange(1, len(mu) + 1)
plt.fill_between(t, lower, upper, alpha=alpha)
plt.semilogy(mu, lw=2, label='binary')

plt.title('Mean-squared error of relative free energies', fontsize=14)
plt.ylabel('Mean squared error', fontsize=12)
plt.xlabel('Iteration', fontsize=12)
plt.legend(fontsize=13)
plt.show()
