Skip to content

Commit

Permalink
Merge pull request #202 from Lnaden/master
Browse files Browse the repository at this point in the history
Ensure free energy is moderately accurate, protocol independent
  • Loading branch information
mrshirts committed Apr 4, 2016
2 parents e2d0ef0 + 6da0f40 commit a903248
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 32 deletions.
23 changes: 2 additions & 21 deletions pymbar/mbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import numpy as np
import numpy.linalg as linalg
from pymbar import mbar_solvers
from pymbar.utils import kln_to_kn, kn_to_n, ParameterError, logsumexp
from pymbar.utils import kln_to_kn, kn_to_n, ParameterError, logsumexp, check_w_normalized

DEFAULT_SOLVER_PROTOCOL = mbar_solvers.DEFAULT_SOLVER_PROTOCOL
DEFAULT_SUBSAMPLING_PROTOCOL = mbar_solvers.DEFAULT_SUBSAMPLING_PROTOCOL
Expand Down Expand Up @@ -1486,26 +1486,7 @@ def _computeAsymptoticCovarianceMatrix(self, W, N_k, method=None):
if(np.sum(N_k) != N):
raise ParameterError('W must be NxK, where N = sum_k N_k.')

# Check to make sure the weight matrix W is properly normalized.
tolerance = 1.0e-4 # tolerance for checking equality of sums

column_sums = np.sum(W, axis=0)
badcolumns = (np.abs(column_sums - 1) > tolerance)
if np.any(badcolumns):
which_badcolumns = np.arange(K)[badcolumns]
firstbad = which_badcolumns[0]
raise ParameterError(
'Warning: Should have \sum_n W_nk = 1. Actual column sum for state %d was %f. %d other columns have similar problems' %
(firstbad, column_sums[firstbad], np.sum(badcolumns)))

row_sums = np.sum(W * N_k, axis=1)
badrows = (np.abs(row_sums - 1) > tolerance)
if np.any(badrows):
which_badrows = np.arange(N)[badrows]
firstbad = which_badrows[0]
raise ParameterError(
'Warning: Should have \sum_k N_k W_nk = 1. Actual row sum for sample %d was %f. %d other rows have similar problems' %
(firstbad, row_sums[firstbad], np.sum(badrows)))
check_w_normalized(W, N_k)

# Compute estimate of asymptotic covariance matrix using specified method.
if method == 'approximate':
Expand Down
27 changes: 20 additions & 7 deletions pymbar/mbar_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np
import math
import scipy.optimize
from pymbar.utils import ensure_type, logsumexp
from pymbar.utils import ensure_type, logsumexp, check_w_normalized
import warnings

# Below are the recommended default protocols (ordered sequence of minimization algorithms / NLE solvers) for solving the MBAR equations.
# Note: we use tuples instead of lists to avoid accidental mutability.
Expand Down Expand Up @@ -313,14 +314,26 @@ def solve_mbar_once(u_kn_nonzero, N_k_nonzero, f_k_nonzero, method="hybr", tol=1
grad_and_obj = lambda x: unpad_second_arg(*mbar_objective_and_gradient(u_kn_nonzero, N_k_nonzero, pad(x))) # Objective function gradient and objective function
hess = lambda x: mbar_hessian(u_kn_nonzero, N_k_nonzero, pad(x))[1:][:, 1:] # Hessian of objective function

if method in ["L-BFGS-B", "dogleg", "CG", "BFGS", "Newton-CG", "TNC", "trust-ncg", "SLSQP"]:
if method in ["L-BFGS-B", "CG"]:
hess = None # To suppress warning from passing a hessian function.
results = scipy.optimize.minimize(grad_and_obj, f_k_nonzero[1:], jac=True, hess=hess, method=method, tol=tol, options=options)
else:
results = scipy.optimize.root(grad, f_k_nonzero[1:], jac=hess, method=method, tol=tol, options=options)
with warnings.catch_warnings(record=True) as w:
if method in ["L-BFGS-B", "dogleg", "CG", "BFGS", "Newton-CG", "TNC", "trust-ncg", "SLSQP"]:
if method in ["L-BFGS-B", "CG"]:
hess = None # To suppress warning from passing a hessian function.
results = scipy.optimize.minimize(grad_and_obj, f_k_nonzero[1:], jac=True, hess=hess, method=method, tol=tol, options=options)
else:
results = scipy.optimize.root(grad, f_k_nonzero[1:], jac=hess, method=method, tol=tol, options=options)

f_k_nonzero = pad(results["x"])

#If there were runtime warnings, show the messages
if len(w) > 0:
for warn_msg in w:
warnings.showwarning(warn_msg.message, warn_msg.category, warn_msg.filename, warn_msg.lineno, warn_msg.file, "")
#Ensure MBAR solved correctly
W_nk_check = mbar_W_nk(u_kn_nonzero, N_k_nonzero, f_k_nonzero)
check_w_normalized(W_nk_check, N_k_nonzero)
print("MBAR weights converged within tolerance, despite the SciPy Warnings. Please validate your results.")


return f_k_nonzero, results


Expand Down
40 changes: 36 additions & 4 deletions pymbar/tests/test_mbar_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,28 @@
from nose import SkipTest


def load_oscillators(n_states, n_samples):
def load_oscillators(n_states, n_samples, provide_test=False):
name = "%dx%d oscillators" % (n_states, n_samples)
O_k = np.linspace(1, 5, n_states)
k_k = np.linspace(1, 3, n_states)
N_k = (np.ones(n_states) * n_samples).astype('int')
test = pymbar.testsystems.harmonic_oscillators.HarmonicOscillatorsTestCase(O_k, k_k)
x_n, u_kn, N_k_output, s_n = test.sample(N_k, mode='u_kn')
return name, u_kn, N_k_output, s_n
returns = [name, u_kn, N_k_output, s_n]
if provide_test:
returns.append(test)
return returns

def load_exponentials(n_states, n_samples):
def load_exponentials(n_states, n_samples, provide_test=False):
name = "%dx%d exponentials" % (n_states, n_samples)
rates = np.linspace(1, 3, n_states)
N_k = (np.ones(n_states) * n_samples).astype('int')
test = pymbar.testsystems.exponential_distributions.ExponentialTestCase(rates)
x_n, u_kn, N_k_output, s_n = test.sample(N_k, mode='u_kn')
return name, u_kn, N_k_output, s_n
returns = [name, u_kn, N_k_output, s_n]
if provide_test:
returns.append(test)
return returns

def _test(data_generator):
try:
Expand Down Expand Up @@ -74,3 +80,29 @@ def test_subsampling():
mbar_sub = pymbar.MBAR(u_kn_sub, N_k_sub)
eq(mbar.f_k, mbar_sub.f_k, decimal=2)

def test_protocols():
'''Test that free energy is moderatley equal to analytical solution, independent of solver protocols'''
#Supress the warnings when jacobian and Hessian information is not used in a specific solver
import warnings
warnings.filterwarnings('ignore', '.*does not use the jacobian.*')
warnings.filterwarnings('ignore', '.*does not use Hessian.*')
from pymbar.tests.test_mbar import z_scale_factor # Importing the hacky fix to asert that free energies are moderatley correct
name, u_kn, N_k, s_n, test = load_oscillators(50, 100, provide_test=True)
fa = test.analytical_free_energies()
fa = fa[1:] - fa[0]

#scipy.optimize.minimize methods, same ones that are checked for in mbar_solvers.py
subsampling_protocols = ["L-BFGS-B", "dogleg", "CG", "BFGS", "Newton-CG", "TNC", "trust-ncg", "SLSQP"]
solver_protocols = ['hybr', 'lm'] #scipy.optimize.root methods. Omitting methods which do not use the Jacobian
for subsampling_protocol in subsampling_protocols:
for solver_protocol in solver_protocols:
#Solve MBAR with zeros for initial weights
mbar = pymbar.MBAR(u_kn, N_k, subsampling_protocol=({'method':subsampling_protocol},), solver_protocol=({'method':solver_protocol},))
#Solve MBAR with the correct f_k used for the inital weights
mbar = pymbar.MBAR(u_kn, N_k, initial_f_k=mbar.f_k, subsampling_protocol=({'method':subsampling_protocol},), solver_protocol=({'method':solver_protocol},))
fe, fe_sigma, Theta_ij = mbar.getFreeEnergyDifferences()
fe, fe_sigma = fe[0,1:], fe_sigma[0,1:]
z = (fe - fa) / fe_sigma
eq(z / z_scale_factor, np.zeros(len(z)), decimal=0)
#Clear warning filters
warnings.resetwarnings()
41 changes: 41 additions & 0 deletions pymbar/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,47 @@ def logsumexp(a, axis=None, b=None, use_numexpr=True):

return out

def check_w_normalized(W, N_k, tolerance = 1.0e-4):
"""Check the weight matrix W is properly normalized. The sum over N should be 1, and the sum over k by N_k should aslo be 1
Parameters
----------
W : np.ndarray, shape=(N, K), dtype='float'
The normalized weight matrix for snapshots and states.
W[n, k] is the weight of snapshot n in state k.
N_k : np.ndarray, shape=(K), dtype='int'
N_k[k] is the number of samples from state k.
tolerance : float, optional, default=1.0e-4
Tolerance for checking equality of sums
Returns
-------
None : NoneType
Returns a None object if test passes, otherwise raises a ParameterError with appropriate message if W is not normalized within tolerance.
"""

[N, K] = W.shape

column_sums = np.sum(W, axis=0)
badcolumns = (np.abs(column_sums - 1) > tolerance)
if np.any(badcolumns):
which_badcolumns = np.arange(K)[badcolumns]
firstbad = which_badcolumns[0]
raise ParameterError(
'Warning: Should have \sum_n W_nk = 1. Actual column sum for state %d was %f. %d other columns have similar problems' %
(firstbad, column_sums[firstbad], np.sum(badcolumns)))

row_sums = np.sum(W * N_k, axis=1)
badrows = (np.abs(row_sums - 1) > tolerance)
if np.any(badrows):
which_badrows = np.arange(N)[badrows]
firstbad = which_badrows[0]
raise ParameterError(
'Warning: Should have \sum_k N_k W_nk = 1. Actual row sum for sample %d was %f. %d other rows have similar problems' %
(firstbad, row_sums[firstbad], np.sum(badrows)))

return

#=============================================================================================
# Exception classes
#=============================================================================================
Expand Down

0 comments on commit a903248

Please sign in to comment.